from tensorflow import keras
from tensorflow.keras import layers
import pathlib
from tensorflow.keras.utils import image_dataset_from_directory
import pandas as pd
import pathlib
from pathlib import Path
import numpy as np
import pandas as pd
# plotting modules
from matplotlib import pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
import plotly as plotly
plotly.offline.init_notebook_mode()
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow as tf
from keras.utils import to_categorical
from keras.models import load_model
import os
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.image import imread
import plotly.graph_objects as go
from tensorflow.keras.models import Sequential
from keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import Dense
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report, precision_recall_curve, ConfusionMatrixDisplay
# Sanity check: confirm TensorFlow can see a GPU (returns a list of PhysicalDevice objects).
tf.config.list_physical_devices('GPU')
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Breast cancer is one of the leading causes of cancer-related deaths worldwide with the *Invasive Ductal Carcinoma (IDC)* being the most common type. Early detection is crucial for effective treatment and improved patient outcomes. However, current detection methods often lack accuracy and reliability. Our project aims to improve breast cancer detection using deep learning techniques.
The original dataset — https://www.ncbi.nlm.nih.gov/pubmed/27563488 and http://spie.org/Publications/Proceedings/Paper/10.1117/12.2043872 — consisted of 162 whole-mount slide images of Breast Cancer (BCa) specimens scanned at 40x. From these, 277,524 patches of size 50 x 50 were extracted (198,738 IDC-negative and 78,786 IDC-positive). Each patch's file name has the format u_xX_yY_classC.png — for example, 10253_idx5_x1351_y1101_class0.png — where u is the patient ID (10253_idx5), X is the x-coordinate of where this patch was cropped from, Y is the y-coordinate of where this patch was cropped from, and C indicates the class, with 0 being non-IDC and 1 being IDC.
# Root of the downloaded histopathology dataset (path is relative to this notebook's location).
data_folder = pathlib.Path("../../../../../Downloads/hispathology")
# Count top-level entries in the dataset root.
len(list(data_folder.glob('*')))
280
# Subdirectory holding one folder per patient, each with 50x50 patch images.
patient_data_folder = data_folder / 'IDC_regular_ps50_idx5'
Here you can see that we have 279 patients.
Each folder represents one patient and contains that patient's image patches.
# Count the per-patient folders (one folder per patient ID).
len(list(patient_data_folder.glob('*')))
279
# Patient directory names double as patient IDs; ignore any stray non-directory entries.
patient_folders = [entry.name for entry in patient_data_folder.glob('*') if entry.is_dir()]
patient_folders[:10]
['10253', '10254', '10255', '10256', '10257', '10258', '10259', '10260', '10261', '10262']
# Tally every patch image, printing a per-class count for each patient.
total_images = 0
for patient_folder in patient_folders:
    patient_folder_path = patient_data_folder / patient_folder
    # Each patient directory holds a '0' (non-IDC) and a '1' (IDC) subfolder.
    for subfolder in ('0', '1'):
        images = list((patient_folder_path / subfolder).glob('*'))
        total_images += len(images)
        print(f'Patient folder {patient_folder}/{subfolder} contains {len(images)} images.')
Patient folder 10253/0 contains 479 images. Patient folder 10253/1 contains 70 images. Patient folder 10254/0 contains 772 images. Patient folder 10254/1 contains 76 images. Patient folder 10255/0 contains 181 images. Patient folder 10255/1 contains 91 images. Patient folder 10256/0 contains 351 images. Patient folder 10256/1 contains 117 images. Patient folder 10257/0 contains 427 images. Patient folder 10257/1 contains 208 images. Patient folder 10258/0 contains 422 images. Patient folder 10258/1 contains 108 images. Patient folder 10259/0 contains 1434 images. Patient folder 10259/1 contains 31 images. Patient folder 10260/0 contains 928 images. Patient folder 10260/1 contains 361 images. Patient folder 10261/0 contains 590 images. Patient folder 10261/1 contains 56 images. Patient folder 10262/0 contains 1053 images. Patient folder 10262/1 contains 754 images. Patient folder 10264/0 contains 617 images. Patient folder 10264/1 contains 587 images. Patient folder 10268/0 contains 2086 images. Patient folder 10268/1 contains 23 images. Patient folder 10269/0 contains 904 images. Patient folder 10269/1 contains 250 images. Patient folder 10272/0 contains 2150 images. Patient folder 10272/1 contains 25 images. Patient folder 10273/0 contains 811 images. Patient folder 10273/1 contains 1211 images. Patient folder 10274/0 contains 659 images. Patient folder 10274/1 contains 219 images. Patient folder 10275/0 contains 297 images. Patient folder 10275/1 contains 760 images. Patient folder 10276/0 contains 591 images. Patient folder 10276/1 contains 348 images. Patient folder 10277/0 contains 785 images. Patient folder 10277/1 contains 170 images. Patient folder 10278/0 contains 1068 images. Patient folder 10278/1 contains 91 images. Patient folder 10279/0 contains 1267 images. Patient folder 10279/1 contains 427 images. Patient folder 10282/0 contains 1835 images. Patient folder 10282/1 contains 198 images. Patient folder 10285/0 contains 1011 images. 
Patient folder 10285/1 contains 222 images. Patient folder 10286/0 contains 458 images. Patient folder 10286/1 contains 162 images. Patient folder 10288/0 contains 2231 images. Patient folder 10288/1 contains 47 images. Patient folder 10290/0 contains 1891 images. Patient folder 10290/1 contains 140 images. Patient folder 10291/0 contains 999 images. Patient folder 10291/1 contains 213 images. Patient folder 10292/0 contains 998 images. Patient folder 10292/1 contains 487 images. Patient folder 10293/0 contains 649 images. Patient folder 10293/1 contains 221 images. Patient folder 10295/0 contains 761 images. Patient folder 10295/1 contains 134 images. Patient folder 10299/0 contains 759 images. Patient folder 10299/1 contains 1347 images. Patient folder 10300/0 contains 1464 images. Patient folder 10300/1 contains 29 images. Patient folder 10301/0 contains 1074 images. Patient folder 10301/1 contains 342 images. Patient folder 10302/0 contains 598 images. Patient folder 10302/1 contains 1309 images. Patient folder 10303/0 contains 579 images. Patient folder 10303/1 contains 773 images. Patient folder 10304/0 contains 779 images. Patient folder 10304/1 contains 101 images. Patient folder 10305/0 contains 1802 images. Patient folder 10305/1 contains 19 images. Patient folder 10306/0 contains 751 images. Patient folder 10306/1 contains 273 images. Patient folder 10307/0 contains 915 images. Patient folder 10307/1 contains 80 images. Patient folder 10308/0 contains 1383 images. Patient folder 10308/1 contains 895 images. Patient folder 12241/0 contains 37 images. Patient folder 12241/1 contains 115 images. Patient folder 12242/0 contains 668 images. Patient folder 12242/1 contains 429 images. Patient folder 12626/0 contains 1088 images. Patient folder 12626/1 contains 254 images. Patient folder 12748/0 contains 168 images. Patient folder 12748/1 contains 198 images. Patient folder 12749/0 contains 1199 images. Patient folder 12749/1 contains 563 images. 
Patient folder 12750/0 contains 1413 images. Patient folder 12750/1 contains 21 images. Patient folder 12751/0 contains 849 images. Patient folder 12751/1 contains 967 images. Patient folder 12752/0 contains 464 images. Patient folder 12752/1 contains 635 images. Patient folder 12810/0 contains 914 images. Patient folder 12810/1 contains 252 images. Patient folder 12811/0 contains 125 images. Patient folder 12811/1 contains 126 images. Patient folder 12817/0 contains 362 images. Patient folder 12817/1 contains 572 images. Patient folder 12818/0 contains 666 images. Patient folder 12818/1 contains 945 images. Patient folder 12819/0 contains 1404 images. Patient folder 12819/1 contains 223 images. Patient folder 12820/0 contains 753 images. Patient folder 12820/1 contains 369 images. Patient folder 12821/0 contains 1066 images. Patient folder 12821/1 contains 319 images. Patient folder 12822/0 contains 490 images. Patient folder 12822/1 contains 271 images. Patient folder 12823/0 contains 556 images. Patient folder 12823/1 contains 447 images. Patient folder 12824/0 contains 597 images. Patient folder 12824/1 contains 110 images. Patient folder 12826/0 contains 963 images. Patient folder 12826/1 contains 174 images. Patient folder 12867/0 contains 851 images. Patient folder 12867/1 contains 575 images. Patient folder 12868/0 contains 500 images. Patient folder 12868/1 contains 361 images. Patient folder 12869/0 contains 778 images. Patient folder 12869/1 contains 18 images. Patient folder 12870/0 contains 788 images. Patient folder 12870/1 contains 41 images. Patient folder 12871/0 contains 146 images. Patient folder 12871/1 contains 36 images. Patient folder 12872/0 contains 711 images. Patient folder 12872/1 contains 69 images. Patient folder 12873/0 contains 49 images. Patient folder 12873/1 contains 232 images. Patient folder 12875/0 contains 331 images. Patient folder 12875/1 contains 43 images. Patient folder 12876/0 contains 50 images. 
Patient folder 12876/1 contains 105 images. Patient folder 12877/0 contains 272 images. Patient folder 12877/1 contains 33 images. Patient folder 12878/0 contains 1289 images. Patient folder 12878/1 contains 185 images. Patient folder 12879/0 contains 272 images. Patient folder 12879/1 contains 144 images. Patient folder 12880/0 contains 788 images. Patient folder 12880/1 contains 1147 images. Patient folder 12881/0 contains 115 images. Patient folder 12881/1 contains 158 images. Patient folder 12882/0 contains 238 images. Patient folder 12882/1 contains 154 images. Patient folder 12883/0 contains 276 images. Patient folder 12883/1 contains 73 images. Patient folder 12884/0 contains 533 images. Patient folder 12884/1 contains 236 images. Patient folder 12886/0 contains 240 images. Patient folder 12886/1 contains 287 images. Patient folder 12890/0 contains 1313 images. Patient folder 12890/1 contains 158 images. Patient folder 12891/0 contains 442 images. Patient folder 12891/1 contains 172 images. Patient folder 12892/0 contains 133 images. Patient folder 12892/1 contains 93 images. Patient folder 12893/0 contains 216 images. Patient folder 12893/1 contains 482 images. Patient folder 12894/0 contains 1066 images. Patient folder 12894/1 contains 650 images. Patient folder 12895/0 contains 939 images. Patient folder 12895/1 contains 741 images. Patient folder 12896/0 contains 424 images. Patient folder 12896/1 contains 83 images. Patient folder 12897/0 contains 565 images. Patient folder 12897/1 contains 296 images. Patient folder 12898/0 contains 372 images. Patient folder 12898/1 contains 208 images. Patient folder 12900/0 contains 573 images. Patient folder 12900/1 contains 450 images. Patient folder 12901/0 contains 578 images. Patient folder 12901/1 contains 230 images. Patient folder 12905/0 contains 1193 images. Patient folder 12905/1 contains 21 images. Patient folder 12906/0 contains 816 images. Patient folder 12906/1 contains 887 images. 
Patient folder 12907/0 contains 546 images. Patient folder 12907/1 contains 504 images. Patient folder 12908/0 contains 778 images. Patient folder 12908/1 contains 262 images. Patient folder 12909/0 contains 283 images. Patient folder 12909/1 contains 514 images. Patient folder 12910/0 contains 1496 images. Patient folder 12910/1 contains 222 images. Patient folder 12911/0 contains 1041 images. Patient folder 12911/1 contains 201 images. Patient folder 12929/0 contains 90 images. Patient folder 12929/1 contains 70 images. Patient folder 12930/0 contains 835 images. Patient folder 12930/1 contains 165 images. Patient folder 12931/0 contains 477 images. Patient folder 12931/1 contains 130 images. Patient folder 12932/0 contains 433 images. Patient folder 12932/1 contains 304 images. Patient folder 12933/0 contains 125 images. Patient folder 12933/1 contains 30 images. Patient folder 12934/0 contains 1500 images. Patient folder 12934/1 contains 504 images. Patient folder 12935/0 contains 611 images. Patient folder 12935/1 contains 615 images. Patient folder 12947/0 contains 378 images. Patient folder 12947/1 contains 452 images. Patient folder 12948/0 contains 80 images. Patient folder 12948/1 contains 87 images. Patient folder 12949/0 contains 344 images. Patient folder 12949/1 contains 464 images. Patient folder 12951/0 contains 808 images. Patient folder 12951/1 contains 330 images. Patient folder 12954/0 contains 1945 images. Patient folder 12954/1 contains 64 images. Patient folder 12955/0 contains 809 images. Patient folder 12955/1 contains 253 images. Patient folder 13018/0 contains 175 images. Patient folder 13018/1 contains 128 images. Patient folder 13019/0 contains 1069 images. Patient folder 13019/1 contains 441 images. Patient folder 13020/0 contains 336 images. Patient folder 13020/1 contains 50 images. Patient folder 13021/0 contains 1089 images. Patient folder 13021/1 contains 108 images. Patient folder 13022/0 contains 1056 images. 
Patient folder 13022/1 contains 286 images. Patient folder 13023/0 contains 112 images. Patient folder 13023/1 contains 115 images. Patient folder 13024/0 contains 624 images. Patient folder 13024/1 contains 186 images. Patient folder 13025/0 contains 465 images. Patient folder 13025/1 contains 296 images. Patient folder 13106/0 contains 979 images. Patient folder 13106/1 contains 155 images. Patient folder 13400/0 contains 1299 images. Patient folder 13400/1 contains 64 images. Patient folder 13401/0 contains 323 images. Patient folder 13401/1 contains 244 images. Patient folder 13402/0 contains 292 images. Patient folder 13402/1 contains 519 images. Patient folder 13403/0 contains 201 images. Patient folder 13403/1 contains 111 images. Patient folder 13404/0 contains 379 images. Patient folder 13404/1 contains 178 images. Patient folder 13458/0 contains 394 images. Patient folder 13458/1 contains 49 images. Patient folder 13459/0 contains 802 images. Patient folder 13459/1 contains 224 images. Patient folder 13460/0 contains 623 images. Patient folder 13460/1 contains 46 images. Patient folder 13461/0 contains 594 images. Patient folder 13461/1 contains 70 images. Patient folder 13462/0 contains 1028 images. Patient folder 13462/1 contains 726 images. Patient folder 13591/0 contains 907 images. Patient folder 13591/1 contains 128 images. Patient folder 13613/0 contains 827 images. Patient folder 13613/1 contains 630 images. Patient folder 13616/0 contains 656 images. Patient folder 13616/1 contains 701 images. Patient folder 13617/0 contains 299 images. Patient folder 13617/1 contains 56 images. Patient folder 13666/0 contains 377 images. Patient folder 13666/1 contains 30 images. Patient folder 13687/0 contains 303 images. Patient folder 13687/1 contains 151 images. Patient folder 13688/0 contains 215 images. Patient folder 13688/1 contains 127 images. Patient folder 13689/0 contains 510 images. Patient folder 13689/1 contains 76 images. 
Patient folder 13691/0 contains 927 images. Patient folder 13691/1 contains 264 images. Patient folder 13692/0 contains 272 images. Patient folder 13692/1 contains 335 images. Patient folder 13693/0 contains 1935 images. Patient folder 13693/1 contains 460 images. Patient folder 13694/0 contains 360 images. Patient folder 13694/1 contains 870 images. Patient folder 13916/0 contains 1268 images. Patient folder 13916/1 contains 365 images. Patient folder 14078/0 contains 100 images. Patient folder 14078/1 contains 121 images. Patient folder 14079/0 contains 435 images. Patient folder 14079/1 contains 455 images. Patient folder 14081/0 contains 117 images. Patient folder 14081/1 contains 209 images. Patient folder 14082/0 contains 281 images. Patient folder 14082/1 contains 197 images. Patient folder 14153/0 contains 579 images. Patient folder 14153/1 contains 210 images. Patient folder 14154/0 contains 691 images. Patient folder 14154/1 contains 829 images. Patient folder 14155/0 contains 671 images. Patient folder 14155/1 contains 1206 images. Patient folder 14156/0 contains 1201 images. Patient folder 14156/1 contains 197 images. Patient folder 14157/0 contains 990 images. Patient folder 14157/1 contains 488 images. Patient folder 14188/0 contains 586 images. Patient folder 14188/1 contains 123 images. Patient folder 14189/0 contains 581 images. Patient folder 14189/1 contains 421 images. Patient folder 14190/0 contains 458 images. Patient folder 14190/1 contains 491 images. Patient folder 14191/0 contains 723 images. Patient folder 14191/1 contains 617 images. Patient folder 14192/0 contains 837 images. Patient folder 14192/1 contains 195 images. Patient folder 14209/0 contains 33 images. Patient folder 14209/1 contains 309 images. Patient folder 14210/0 contains 469 images. Patient folder 14210/1 contains 104 images. Patient folder 14211/0 contains 1287 images. Patient folder 14211/1 contains 809 images. Patient folder 14212/0 contains 167 images. 
Patient folder 14212/1 contains 44 images. Patient folder 14213/0 contains 169 images. Patient folder 14213/1 contains 253 images. Patient folder 14304/0 contains 410 images. Patient folder 14304/1 contains 432 images. Patient folder 14305/0 contains 714 images. Patient folder 14305/1 contains 272 images. Patient folder 14306/0 contains 264 images. Patient folder 14306/1 contains 167 images. Patient folder 14321/0 contains 426 images. Patient folder 14321/1 contains 195 images. Patient folder 15471/0 contains 448 images. Patient folder 15471/1 contains 86 images. Patient folder 15472/0 contains 1490 images. Patient folder 15472/1 contains 214 images. Patient folder 15473/0 contains 553 images. Patient folder 15473/1 contains 885 images. Patient folder 15510/0 contains 705 images. Patient folder 15510/1 contains 356 images. Patient folder 15512/0 contains 79 images. Patient folder 15512/1 contains 143 images. Patient folder 15513/0 contains 815 images. Patient folder 15513/1 contains 54 images. Patient folder 15514/0 contains 197 images. Patient folder 15514/1 contains 441 images. Patient folder 15515/0 contains 1051 images. Patient folder 15515/1 contains 111 images. Patient folder 15516/0 contains 1016 images. Patient folder 15516/1 contains 275 images. Patient folder 15632/0 contains 373 images. Patient folder 15632/1 contains 120 images. Patient folder 15633/0 contains 114 images. Patient folder 15633/1 contains 337 images. Patient folder 15634/0 contains 439 images. Patient folder 15634/1 contains 370 images. Patient folder 15839/0 contains 105 images. Patient folder 15839/1 contains 134 images. Patient folder 15840/0 contains 862 images. Patient folder 15840/1 contains 244 images. Patient folder 15902/0 contains 706 images. Patient folder 15902/1 contains 461 images. Patient folder 15903/0 contains 418 images. Patient folder 15903/1 contains 621 images. Patient folder 16014/0 contains 497 images. Patient folder 16014/1 contains 209 images. 
Patient folder 16085/0 contains 1913 images. Patient folder 16085/1 contains 24 images. Patient folder 16165/0 contains 937 images. Patient folder 16165/1 contains 1174 images. Patient folder 16166/0 contains 615 images. Patient folder 16166/1 contains 675 images. Patient folder 16167/0 contains 96 images. Patient folder 16167/1 contains 96 images. Patient folder 16531/0 contains 191 images. Patient folder 16531/1 contains 58 images. Patient folder 16532/0 contains 339 images. Patient folder 16532/1 contains 128 images. Patient folder 16533/0 contains 240 images. Patient folder 16533/1 contains 127 images. Patient folder 16534/0 contains 21 images. Patient folder 16534/1 contains 42 images. Patient folder 16550/0 contains 2115 images. Patient folder 16550/1 contains 187 images. Patient folder 16551/0 contains 1899 images. Patient folder 16551/1 contains 284 images. Patient folder 16552/0 contains 150 images. Patient folder 16552/1 contains 37 images. Patient folder 16553/0 contains 327 images. Patient folder 16553/1 contains 353 images. Patient folder 16554/0 contains 269 images. Patient folder 16554/1 contains 448 images. Patient folder 16555/0 contains 315 images. Patient folder 16555/1 contains 85 images. Patient folder 16568/0 contains 545 images. Patient folder 16568/1 contains 283 images. Patient folder 16569/0 contains 302 images. Patient folder 16569/1 contains 35 images. Patient folder 16570/0 contains 375 images. Patient folder 16570/1 contains 542 images. Patient folder 16895/0 contains 115 images. Patient folder 16895/1 contains 36 images. Patient folder 16896/0 contains 1017 images. Patient folder 16896/1 contains 110 images. Patient folder 8863/0 contains 772 images. Patient folder 8863/1 contains 207 images. Patient folder 8864/0 contains 805 images. Patient folder 8864/1 contains 328 images. Patient folder 8865/0 contains 657 images. Patient folder 8865/1 contains 55 images. Patient folder 8867/0 contains 1480 images. 
Patient folder 8867/1 contains 162 images. Patient folder 8913/0 contains 873 images. Patient folder 8913/1 contains 82 images. Patient folder 8914/0 contains 978 images. Patient folder 8914/1 contains 75 images. Patient folder 8916/0 contains 60 images. Patient folder 8916/1 contains 111 images. Patient folder 8917/0 contains 578 images. Patient folder 8917/1 contains 397 images. Patient folder 8918/0 contains 1421 images. Patient folder 8918/1 contains 120 images. Patient folder 8950/0 contains 420 images. Patient folder 8950/1 contains 190 images. Patient folder 8951/0 contains 433 images. Patient folder 8951/1 contains 180 images. Patient folder 8955/0 contains 314 images. Patient folder 8955/1 contains 181 images. Patient folder 8956/0 contains 1485 images. Patient folder 8956/1 contains 340 images. Patient folder 8957/0 contains 28 images. Patient folder 8957/1 contains 83 images. Patient folder 8959/0 contains 152 images. Patient folder 8959/1 contains 204 images. Patient folder 8974/0 contains 1372 images. Patient folder 8974/1 contains 369 images. Patient folder 8975/0 contains 1379 images. Patient folder 8975/1 contains 833 images. Patient folder 8980/0 contains 487 images. Patient folder 8980/1 contains 209 images. Patient folder 8984/0 contains 962 images. Patient folder 8984/1 contains 156 images. Patient folder 9022/0 contains 418 images. Patient folder 9022/1 contains 99 images. Patient folder 9023/0 contains 583 images. Patient folder 9023/1 contains 288 images. Patient folder 9029/0 contains 1497 images. Patient folder 9029/1 contains 137 images. Patient folder 9035/0 contains 185 images. Patient folder 9035/1 contains 51 images. Patient folder 9036/0 contains 1276 images. Patient folder 9036/1 contains 30 images. Patient folder 9037/0 contains 924 images. Patient folder 9037/1 contains 188 images. Patient folder 9041/0 contains 857 images. Patient folder 9041/1 contains 178 images. Patient folder 9043/0 contains 276 images. 
Patient folder 9043/1 contains 560 images. Patient folder 9044/0 contains 112 images. Patient folder 9044/1 contains 46 images. Patient folder 9073/0 contains 771 images. Patient folder 9073/1 contains 63 images. Patient folder 9075/0 contains 1420 images. Patient folder 9075/1 contains 361 images. Patient folder 9076/0 contains 832 images. Patient folder 9076/1 contains 159 images. Patient folder 9077/0 contains 360 images. Patient folder 9077/1 contains 1263 images. Patient folder 9078/0 contains 1602 images. Patient folder 9078/1 contains 186 images. Patient folder 9081/0 contains 681 images. Patient folder 9081/1 contains 180 images. Patient folder 9083/0 contains 373 images. Patient folder 9083/1 contains 198 images. Patient folder 9123/0 contains 1427 images. Patient folder 9123/1 contains 161 images. Patient folder 9124/0 contains 175 images. Patient folder 9124/1 contains 290 images. Patient folder 9125/0 contains 369 images. Patient folder 9125/1 contains 239 images. Patient folder 9126/0 contains 1147 images. Patient folder 9126/1 contains 447 images. Patient folder 9135/0 contains 604 images. Patient folder 9135/1 contains 111 images. Patient folder 9173/0 contains 1020 images. Patient folder 9173/1 contains 485 images. Patient folder 9174/0 contains 190 images. Patient folder 9174/1 contains 29 images. Patient folder 9175/0 contains 108 images. Patient folder 9175/1 contains 10 images. Patient folder 9176/0 contains 636 images. Patient folder 9176/1 contains 409 images. Patient folder 9177/0 contains 842 images. Patient folder 9177/1 contains 261 images. Patient folder 9178/0 contains 1283 images. Patient folder 9178/1 contains 160 images. Patient folder 9181/0 contains 915 images. Patient folder 9181/1 contains 161 images. Patient folder 9225/0 contains 1454 images. Patient folder 9225/1 contains 89 images. Patient folder 9226/0 contains 817 images. Patient folder 9226/1 contains 421 images. Patient folder 9227/0 contains 687 images. 
Patient folder 9227/1 contains 127 images. Patient folder 9228/0 contains 412 images. Patient folder 9228/1 contains 71 images. Patient folder 9250/0 contains 636 images. Patient folder 9250/1 contains 515 images. Patient folder 9254/0 contains 999 images. Patient folder 9254/1 contains 173 images. Patient folder 9255/0 contains 838 images. Patient folder 9255/1 contains 388 images. Patient folder 9256/0 contains 489 images. Patient folder 9256/1 contains 442 images. Patient folder 9257/0 contains 1001 images. Patient folder 9257/1 contains 201 images. Patient folder 9258/0 contains 264 images. Patient folder 9258/1 contains 334 images. Patient folder 9259/0 contains 1174 images. Patient folder 9259/1 contains 225 images. Patient folder 9260/0 contains 236 images. Patient folder 9260/1 contains 126 images. Patient folder 9261/0 contains 521 images. Patient folder 9261/1 contains 446 images. Patient folder 9262/0 contains 14 images. Patient folder 9262/1 contains 80 images. Patient folder 9265/0 contains 1636 images. Patient folder 9265/1 contains 41 images. Patient folder 9266/0 contains 1119 images. Patient folder 9266/1 contains 52 images. Patient folder 9267/0 contains 337 images. Patient folder 9267/1 contains 321 images. Patient folder 9290/0 contains 1368 images. Patient folder 9290/1 contains 174 images. Patient folder 9291/0 contains 733 images. Patient folder 9291/1 contains 96 images. Patient folder 9319/0 contains 385 images. Patient folder 9319/1 contains 31 images. Patient folder 9320/0 contains 1453 images. Patient folder 9320/1 contains 451 images. Patient folder 9321/0 contains 282 images. Patient folder 9321/1 contains 30 images. Patient folder 9322/0 contains 1295 images. Patient folder 9322/1 contains 167 images. Patient folder 9323/0 contains 1938 images. Patient folder 9323/1 contains 278 images. Patient folder 9324/0 contains 720 images. Patient folder 9324/1 contains 322 images. Patient folder 9325/0 contains 1060 images. 
Patient folder 9325/1 contains 68 images. Patient folder 9344/0 contains 225 images. Patient folder 9344/1 contains 310 images. Patient folder 9345/0 contains 554 images. Patient folder 9345/1 contains 631 images. Patient folder 9346/0 contains 634 images. Patient folder 9346/1 contains 727 images. Patient folder 9347/0 contains 359 images. Patient folder 9347/1 contains 51 images. Patient folder 9381/0 contains 1198 images. Patient folder 9381/1 contains 128 images. Patient folder 9382/0 contains 1306 images. Patient folder 9382/1 contains 346 images. Patient folder 9383/0 contains 494 images. Patient folder 9383/1 contains 70 images.
# Report the grand total across all patients and both classes.
print(f'Total number of images in all subfolders is {total_images}.')
Total number of images in all subfolders is 277524.
We have a total of 277,524 image samples in our dataset.
# Build an index DataFrame with one row per patch image:
# patient_id, file path, and binary target (0 = non-IDC, 1 = IDC).
patient_folders = os.listdir(patient_data_folder)  # List of patient folder names
# Initialize an empty list to collect data
data_list = []
# Iterate over patient folders and their subfolders
for patient_id in patient_folders:
    patient_path = patient_data_folder / patient_id
    # os.listdir can return stray files (e.g. hidden/system files); only descend
    # into real patient directories so the loop below doesn't crash on them.
    if not patient_path.is_dir():
        continue
    for target in ('0', '1'):
        class_path = patient_path / target
        # Guard against a patient that lacks one of the class subfolders.
        if not class_path.is_dir():
            continue
        for image_file in os.listdir(class_path):
            image_path = class_path / image_file
            # Append a new dict to the list
            data_list.append({
                "patient_id": patient_id,
                "path": str(image_path),
                "target": int(target)
            })
# Create the DataFrame from the list of dicts (much faster than appending row-wise).
dataset = pd.DataFrame(data_list)
dataset.head()
| patient_id | path | target | |
|---|---|---|---|
| 0 | 10253 | ..\..\..\..\..\Downloads\hispathology\IDC_regu... | 0 |
| 1 | 10253 | ..\..\..\..\..\Downloads\hispathology\IDC_regu... | 0 |
| 2 | 10253 | ..\..\..\..\..\Downloads\hispathology\IDC_regu... | 0 |
| 3 | 10253 | ..\..\..\..\..\Downloads\hispathology\IDC_regu... | 0 |
| 4 | 10253 | ..\..\..\..\..\Downloads\hispathology\IDC_regu... | 0 |
# (rows, columns) — expect one row per patch image.
dataset.shape
(277524, 3)
# Class balance: counts of non-IDC (0) vs IDC (1) patches.
dataset.target.value_counts()
target 0 198738 1 78786 Name: count, dtype: int64
import plotly.graph_objects as go
# Calculate the percentage of cancerous patches for each patient.
# Result: rows = patient_id, columns = target class (0/1), values = class fraction.
cancer_perc = dataset.groupby("patient_id")['target'].value_counts(normalize=True).unstack()
# Number of patches per patient
patches_per_patient = dataset.groupby("patient_id").size()
# Plot 1: Histogram of the number of patches per patient
fig1 = go.Figure(go.Histogram(x=patches_per_patient, nbinsx=30, marker_color="steelblue"))
fig1.update_layout(title_text="How many patches do we have per patient?", xaxis_title="Number of Patches", yaxis_title="Frequency", height=600, width=800)
fig1.update_traces(marker_line_width=1, marker_line_color="black", opacity=0.8)
fig1.show()
# Plot 2: Histogram of the percentage of cancerous patches per patient
# Ensure there is a '1' column for cancerous patches; if not, create it with 0 values
# (unstack() only creates columns for classes that actually occur).
if 1 not in cancer_perc.columns:
    cancer_perc[1] = 0
fig2 = go.Figure(go.Histogram(x=cancer_perc[1], nbinsx=30, marker_color="mediumseagreen"))
fig2.update_layout(title_text="How much percentage of an image is covered by IDC?", xaxis_title="% of Patches with IDC", yaxis_title="Frequency", height=600, width=800)
fig2.update_traces(marker_line_width=1, marker_line_color="black", opacity=0.8)
fig2.show()
# Plot 3: Count plot of non-cancerous vs. cancerous patches
# NOTE(review): passing a list to marker_color on a Histogram colors points, not bins —
# confirm the intended two-color bars actually render; a go.Bar over value_counts may be safer.
fig3 = go.Figure(go.Histogram(x=dataset['target'], nbinsx=2, marker_color=["darkorchid", "darkorange"]))
fig3.update_layout(title_text="How many patches show IDC?", xaxis_title="No (0) vs Yes (1)", yaxis_title="Count", height=600, width=800)
fig3.update_traces(marker_line_width=1, marker_line_color="black", opacity=0.8)
fig3.show()
# plot pie chart of the distribution of classes
# Labels line up with value_counts() order because class 0 is the majority class
# and value_counts sorts by descending count.
fig = go.Figure(data=[go.Pie(labels=['No IDC', 'IDC'], values=dataset.target.value_counts(), hole=.3)])
fig.update_layout(title_text="Distribution of Classes in the Dataset", height=600, width=800)
fig.show()
# Draw 50 random examples of each class for visual inspection.
healthy = np.random.choice(dataset[dataset.target==0].index.values, size=50, replace=False)
non_healthy = np.random.choice(dataset[dataset.target==1].index.values, size=50, replace=False)
def _plot_patch_grid(sample_indices, rows=5, cols=10):
    """Display the patches at *sample_indices* from `dataset` in a rows x cols grid.

    The two class grids below used identical copy-pasted loops; this helper
    removes that duplication.
    """
    fig, ax = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows))
    for n in range(rows):
        for m in range(cols):
            idx = sample_indices[m + cols * n]  # Index of the sample to display
            image_path = dataset.loc[idx, "path"]  # Get the path of the image
            ax[n, m].imshow(imread(image_path))  # Load and display the image
            ax[n, m].axis('off')  # Hide the axes
    plt.show()
_plot_patch_grid(healthy)      # non-IDC (class 0) samples
_plot_patch_grid(non_healthy)  # IDC (class 1) samples
So far we have looked at individual patches of the breast tissue; now let us attempt to visualize the entire tissue slice.
import pandas as pd
import numpy as np
def extract_coords(df):
    """Return a copy of *df* with integer patch coordinates in columns 'x' and 'y'.

    Coordinates are parsed from the patch file name, which follows the
    pattern u_xX_yY_classC.png (e.g. 10253_idx5_x1351_y1101_class0.png).

    Parameters
    ----------
    df : pd.DataFrame
        Must contain a 'path' column holding the patch file paths.

    Returns
    -------
    pd.DataFrame
        A copy of *df* (the input is not mutated) with added int columns 'x', 'y'.
    """
    df_copy = df.copy()  # avoid mutating the caller's frame / SettingWithCopyWarning
    # Extract both coordinates in a single anchored pass. The original
    # rsplit + str.replace("x", "", case=False) approach stripped every
    # x/X anywhere in the token; anchoring on "_class" also avoids false
    # matches such as the "x5" inside "idx5".
    coords = df_copy['path'].str.extract(r'[xX](\d+)_[yY](\d+)_class', expand=True)
    df_copy['x'] = coords[0].astype(int)
    df_copy['y'] = coords[1].astype(int)
    return df_copy
def get_patient_dataframe(patient_id, dataset):
    """Return all patch rows belonging to *patient_id*, with x/y coordinates attached."""
    # Select this patient's patches, then parse the patch coordinates
    # out of the file paths before returning.
    patient_rows = dataset[dataset['patient_id'] == patient_id].copy()
    return extract_coords(patient_rows)
# Preview the per-patch coordinates extracted for one example patient.
get_patient_dataframe("10253", dataset).head()
| patient_id | path | target | x | y | |
|---|---|---|---|---|---|
| 0 | 10253 | ..\..\..\..\..\Downloads\hispathology\IDC_regu... | 0 | 1001 | 1001 |
| 1 | 10253 | ..\..\..\..\..\Downloads\hispathology\IDC_regu... | 0 | 1001 | 1051 |
| 2 | 10253 | ..\..\..\..\..\Downloads\hispathology\IDC_regu... | 0 | 1001 | 1101 |
| 3 | 10253 | ..\..\..\..\..\Downloads\hispathology\IDC_regu... | 0 | 1001 | 1151 |
| 4 | 10253 | ..\..\..\..\..\Downloads\hispathology\IDC_regu... | 0 | 1001 | 1201 |
Binary target visualisation per tissue slice
# Binary target map per tissue slice: one scatter per patient, where each
# point is a patch location colored by its IDC target (coolwarm: red = IDC).
fig, ax = plt.subplots(5, 3, figsize=(20, 27))
# Get unique patient IDs from your dataset
patient_ids = dataset.patient_id.unique()
for n in range(5):
    for m in range(3):
        patient_id = patient_ids[m + 3 * n]
        example_df = get_patient_dataframe(patient_id, dataset)
        ax[n, m].scatter(example_df.x.values, example_df.y.values,
                         c=example_df.target.values, cmap="coolwarm", s=20)
        ax[n, m].set_title("patient " + patient_id)
        # x values go on the horizontal axis and y values on the vertical
        # axis, so label them accordingly (the original labels were swapped).
        ax[n, m].set_xlabel("x coord")
        ax[n, m].set_ylabel("y coord")
Insights
Sometimes we don't have the full tissue information: it seems that some tissue patches were discarded or lost during preparation.
According to the papers associated with this dataset (linked in the introduction above), this may also be a result of the preprocessing pipeline.
After visualising the binary target maps, it's time to go one step deeper with our EDA: given the coordinates of the image patches, we can try to reconstruct the whole tissue image (not only the targets).
def visualise_breast_tissue(patient_id, pred_df=None):
    """Stitch a patient's 50x50 patches back into a whole tissue image.

    Parameters
    ----------
    patient_id : str
        Patient whose patches are reassembled (looked up in the module-level
        `dataset`).
    pred_df : pd.DataFrame, optional
        Per-patch predictions with columns 'patient_id', 'x', 'y' and 'proba'.
        When given, a probability heat map is filled in as well.

    Returns
    -------
    tuple of (grid, mask, broken_patches, mask_proba)
        grid: RGB uint8 reconstruction (white background);
        mask: red overlay marking IDC-positive patches;
        broken_patches: list of paths that failed to paste;
        mask_proba: per-pixel predicted probability (all zeros when
        pred_df is None).
    """
    example_df = get_patient_dataframe(patient_id, dataset)
    # Coordinates are 1-based, so the largest start offset is max - 1.
    max_point = [example_df.y.max() - 1, example_df.x.max() - 1]
    canvas_shape = (max_point[0] + 50, max_point[1] + 50)
    grid = 255 * np.ones(shape=canvas_shape + (3,)).astype(np.uint8)
    mask = 255 * np.ones(shape=canvas_shape + (3,)).astype(np.uint8)
    # BUG FIX: always allocate the probability map. The original created it
    # only inside `if pred_df is not None`, so the return statement raised
    # NameError for the default pred_df=None call.
    mask_proba = np.zeros(shape=canvas_shape + (1,)).astype(float)
    patient_df = None
    if pred_df is not None:
        patient_df = pred_df[pred_df['patient_id'] == patient_id].copy()
    broken_patches = []
    for n in range(len(example_df)):
        try:
            image = imread(example_df.path.values[n])
            # Convert the image from normalized floats to uint8
            image = (image * 255).astype(np.uint8)
            target = example_df.target.values[n]
            x_coord, y_coord = int(example_df.x.values[n]), int(example_df.y.values[n])
            x_start, y_start = x_coord - 1, y_coord - 1
            x_end, y_end = x_start + 50, y_start + 50
            grid[y_start:y_end, x_start:x_end] = image
            if target == 1:
                # Paint IDC-positive patches red on the overlay mask.
                mask[y_start:y_end, x_start:x_end, 0] = 250
                mask[y_start:y_end, x_start:x_end, 1] = 0
                mask[y_start:y_end, x_start:x_end, 2] = 0
            if patient_df is not None:
                proba = patient_df[(patient_df['x'] == x_coord) & (patient_df['y'] == y_coord)]['proba']
                # .iloc[0] instead of float(Series): float() on a Series is
                # deprecated and fails for len != 1.
                mask_proba[y_start:y_end, x_start:x_end, 0] = proba.iloc[0]
        except ValueError:
            # Patch could not be pasted (e.g. unexpected shape at the slide
            # border); record it and keep going.
            broken_patches.append(example_df.iloc[n]['path'])
    return grid, mask, broken_patches, mask_proba
# Reconstruct and display one patient's slide, side by side with the red
# cancer-mask overlay.
patient_id = '10262'
grid, mask, broken_patches, mask_proba = visualise_breast_tissue(patient_id)
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
ax[0].imshow(grid, alpha=0.9)
# Blend the mask under a semi-transparent copy of the tissue image.
ax[1].imshow(mask, alpha=0.8)
ax[1].imshow(grid, alpha=0.7)
for m in range(2):
    ax[m].grid(False)
    # imshow puts x coordinates on the horizontal axis and y on the vertical
    # axis; the original labelled both axes "y-coord".
    ax[m].set_xlabel("x-coord")
    ax[m].set_ylabel("y-coord")
ax[0].set_title("Breast tissue slice of patient: " + patient_id)
ax[1].set_title("Cancer tissue colored red \n of patient: " + patient_id);
dataset
| patient_id | path | target | |
|---|---|---|---|
| 0 | 10253 | ..\..\..\..\..\Downloads\hispathology\IDC_regu... | 0 |
| 1 | 10253 | ..\..\..\..\..\Downloads\hispathology\IDC_regu... | 0 |
| 2 | 10253 | ..\..\..\..\..\Downloads\hispathology\IDC_regu... | 0 |
| 3 | 10253 | ..\..\..\..\..\Downloads\hispathology\IDC_regu... | 0 |
| 4 | 10253 | ..\..\..\..\..\Downloads\hispathology\IDC_regu... | 0 |
| ... | ... | ... | ... |
| 277519 | 9383 | ..\..\..\..\..\Downloads\hispathology\IDC_regu... | 1 |
| 277520 | 9383 | ..\..\..\..\..\Downloads\hispathology\IDC_regu... | 1 |
| 277521 | 9383 | ..\..\..\..\..\Downloads\hispathology\IDC_regu... | 1 |
| 277522 | 9383 | ..\..\..\..\..\Downloads\hispathology\IDC_regu... | 1 |
| 277523 | 9383 | ..\..\..\..\..\Downloads\hispathology\IDC_regu... | 1 |
277524 rows × 3 columns
One way we can try to handle the class imbalance is by taking an equal sample of each class. In this case, since we have exactly 78,786 positive samples, we will also only take 78,786 negative samples.
# Build a class-balanced subset by downsampling the majority class to the
# minority-class count (78,786 IDC-positive patches in this dataset).
class_0_df = dataset[dataset['target'] == 0]
class_1_df = dataset[dataset['target'] == 1]
# Derive the per-class sample size instead of hard-coding 78786 (the original
# comment even said "2500"); this stays correct if the dataset changes.
n_per_class = min(len(class_0_df), len(class_1_df))
class_0_sample = class_0_df.sample(n_per_class, random_state=42)
class_1_sample = class_1_df.sample(n_per_class, random_state=42)
# Concatenate the two samples to create a balanced smaller dataset
small_dataset = pd.concat([class_0_sample, class_1_sample])
# Shuffle the small_dataset to mix class_0 and class_1 samples
small_dataset = small_dataset.sample(frac=1, random_state=42).reset_index(drop=True)
# Split into 70% train / 15% validation / 15% test.
X = dataset.drop(columns=['target'])
# Targets as strings because flow_from_dataframe expects string class labels.
y = dataset['target'].astype(str)
y = pd.DataFrame(y)
from sklearn.model_selection import train_test_split
# Stratify on the label so every split keeps the original ~72/28 class ratio;
# an unstratified split can skew metrics on an imbalanced dataset like this.
X_train, X_temp, y_train, y_temp = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y)
# Split the temporary set in half for validation and test.
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp)
train_df = pd.concat([X_train, y_train], axis=1)
val_df = pd.concat([X_val, y_val], axis=1)
test_df = pd.concat([X_test, y_test], axis=1)
An alternative way would be to keep the samples as they are and then assign class weights to the classes during training, so the model would give more priority to the positive class.
from sklearn.utils import class_weight
import numpy as np

# Weight each class inversely to its frequency (the "balanced" heuristic) so
# the minority IDC-positive class contributes more to the training loss.
balanced_weights = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=np.unique(y_train),
    y=y_train['target'],
)
# Keras expects a {class_index: weight} mapping.
class_weights = {label: weight for label, weight in enumerate(balanced_weights)}
class_weights
{0: 0.6985221674876847, 1: 1.7593052109181142}
We ended up using the second approach.
In this section, we are generating rgb values from our image dataset that would be fed into our models.
First we rescale all images; the augmentation options below are applied only to the training dataset.
We are also applying some data augmentation techniques to potentially improve the performance of our machine learning models.
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Random augmentations applied to the training images only; every normalization
# option is explicitly disabled so the generators do nothing but rescale.
augmentation_options = dict(
    featurewise_center=False,             # do not zero the dataset mean
    samplewise_center=False,              # do not zero each sample mean
    featurewise_std_normalization=False,  # do not divide by the dataset std
    samplewise_std_normalization=False,   # do not divide by each sample's std
    zca_whitening=False,                  # no ZCA whitening
    rotation_range=20,                    # random rotations up to 20 degrees
    width_shift_range=0.2,                # horizontal shift (fraction of width)
    height_shift_range=0.2,               # vertical shift (fraction of height)
    horizontal_flip=True,                 # random horizontal flips
    vertical_flip=True,                   # random vertical flips
)

# Pixel values are rescaled to [0, 1] everywhere; augmentation is train-only.
train_datagen = ImageDataGenerator(rescale=1. / 255, **augmentation_options)
val_datagen = ImageDataGenerator(rescale=1. / 255)
This is where we generate our train, valid and test generators from the function we defined above.
We also resized the 50 x 50 patches up to 180 x 180 pixels, because that is the input size our models expect.
def _flow_from(datagen, frame, shuffle):
    """Build a (image batch, one-hot label) generator from *frame*.

    Images are read from the 'path' column, resized to 180x180 (the input
    size the models below expect; the raw patches are 50x50), and labelled
    from the string 'target' column in categorical (one-hot) mode.
    """
    return datagen.flow_from_dataframe(
        dataframe=frame,
        x_col="path",
        y_col="target",
        batch_size=32,
        seed=42,
        shuffle=shuffle,
        class_mode="categorical",
        target_size=(180, 180))


# The three generators share all settings except the source frame and
# shuffling; the original repeated the full configuration three times.
train_generator = _flow_from(train_datagen, train_df, shuffle=True)
valid_generator = _flow_from(val_datagen, val_df, shuffle=True)
# Keep the test set unshuffled so predictions line up with test_generator.classes.
test_generator = _flow_from(val_datagen, test_df, shuffle=False)
Found 194266 validated image filenames belonging to 2 classes. Found 41629 validated image filenames belonging to 2 classes. Found 41629 validated image filenames belonging to 2 classes.
Here we use a simple convolutional neural network. It features two convolutional layers for feature extraction, followed by max pooling and dropout for regularization. The network includes a fully connected layer, another dropout stage, and a softmax output layer. It's compiled with categorical crossentropy loss and the Adadelta optimizer, trained with class weights to address class imbalance.
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

# Baseline CNN: two conv layers for feature extraction, max pooling plus
# dropout for regularisation, one dense hidden layer, and a 2-way softmax head.
model = Sequential([
    Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(180, 180, 3)),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(2, activation='softmax'),
])
model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adadelta(),
              metrics=['accuracy'])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 178, 178, 32) 896
conv2d_1 (Conv2D) (None, 176, 176, 64) 18496
max_pooling2d (MaxPooling2D (None, 88, 88, 64) 0
)
dropout (Dropout) (None, 88, 88, 64) 0
flatten (Flatten) (None, 495616) 0
dense (Dense) (None, 128) 63438976
dropout_1 (Dropout) (None, 128) 0
dense_1 (Dense) (None, 2) 258
=================================================================
Total params: 63,458,626
Trainable params: 63,458,626
Non-trainable params: 0
_________________________________________________________________
# Persist the weights with the lowest validation loss seen during training.
# ModelCheckpoint does not create missing directories, so make sure it exists.
os.makedirs("./models", exist_ok=True)
callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath="./models/convnet_from_scratch_with_augmentation.keras",
        save_best_only=True,
        monitor="val_loss")
]
# The batch size comes from the generator (32). Passing batch_size to fit()
# together with a generator is invalid/ignored, so the original
# batch_size=1024 argument is removed.
history = model.fit(
    train_generator,
    epochs=20,
    validation_data=valid_generator,
    class_weight=class_weights,
    callbacks=callbacks)
Epoch 1/20 6071/6071 [==============================] - 816s 133ms/step - loss: 0.5404 - accuracy: 0.7557 - val_loss: 0.4911 - val_accuracy: 0.7863 Epoch 2/20 6071/6071 [==============================] - 814s 134ms/step - loss: 0.4887 - accuracy: 0.7908 - val_loss: 0.4589 - val_accuracy: 0.8010 Epoch 3/20 6071/6071 [==============================] - 812s 134ms/step - loss: 0.4798 - accuracy: 0.7937 - val_loss: 0.4736 - val_accuracy: 0.7901 Epoch 4/20 6071/6071 [==============================] - 811s 134ms/step - loss: 0.4756 - accuracy: 0.7945 - val_loss: 0.4402 - val_accuracy: 0.8053 Epoch 5/20 6071/6071 [==============================] - 812s 134ms/step - loss: 0.4711 - accuracy: 0.7959 - val_loss: 0.4219 - val_accuracy: 0.8148 Epoch 6/20 6071/6071 [==============================] - 814s 134ms/step - loss: 0.4693 - accuracy: 0.7965 - val_loss: 0.4410 - val_accuracy: 0.8050 Epoch 7/20 6071/6071 [==============================] - 811s 134ms/step - loss: 0.4668 - accuracy: 0.7976 - val_loss: 0.4316 - val_accuracy: 0.8081 Epoch 8/20 6071/6071 [==============================] - 812s 134ms/step - loss: 0.4642 - accuracy: 0.7985 - val_loss: 0.4413 - val_accuracy: 0.8041 Epoch 9/20 6071/6071 [==============================] - 812s 134ms/step - loss: 0.4631 - accuracy: 0.7992 - val_loss: 0.4143 - val_accuracy: 0.8198 Epoch 10/20 6071/6071 [==============================] - 813s 134ms/step - loss: 0.4609 - accuracy: 0.7999 - val_loss: 0.4466 - val_accuracy: 0.7965 Epoch 11/20 6071/6071 [==============================] - 815s 134ms/step - loss: 0.4597 - accuracy: 0.8010 - val_loss: 0.4166 - val_accuracy: 0.8158 Epoch 12/20 6071/6071 [==============================] - 814s 134ms/step - loss: 0.4575 - accuracy: 0.8018 - val_loss: 0.4212 - val_accuracy: 0.8155 Epoch 13/20 6071/6071 [==============================] - 811s 134ms/step - loss: 0.4560 - accuracy: 0.8020 - val_loss: 0.4353 - val_accuracy: 0.8024 Epoch 14/20 6071/6071 [==============================] - 816s 
134ms/step - loss: 0.4550 - accuracy: 0.8026 - val_loss: 0.4339 - val_accuracy: 0.8085 Epoch 15/20 6071/6071 [==============================] - 817s 135ms/step - loss: 0.4537 - accuracy: 0.8038 - val_loss: 0.4358 - val_accuracy: 0.8035 Epoch 16/20 6071/6071 [==============================] - 818s 135ms/step - loss: 0.4519 - accuracy: 0.8044 - val_loss: 0.4062 - val_accuracy: 0.8214 Epoch 17/20 6071/6071 [==============================] - 814s 134ms/step - loss: 0.4517 - accuracy: 0.8041 - val_loss: 0.4278 - val_accuracy: 0.8090 Epoch 18/20 6071/6071 [==============================] - 815s 134ms/step - loss: 0.4496 - accuracy: 0.8055 - val_loss: 0.3912 - val_accuracy: 0.8297 Epoch 19/20 6071/6071 [==============================] - 814s 134ms/step - loss: 0.4490 - accuracy: 0.8059 - val_loss: 0.4379 - val_accuracy: 0.8044 Epoch 20/20 6071/6071 [==============================] - 817s 135ms/step - loss: 0.4480 - accuracy: 0.8058 - val_loss: 0.4222 - val_accuracy: 0.8127
# Tabulate the per-epoch training metrics with a leading 1-based epoch column.
history_df = pd.DataFrame(history.history)
history_df.insert(0, 'epoch', [i + 1 for i in range(len(history_df))])
history_df
| epoch | loss | accuracy | val_loss | val_accuracy | |
|---|---|---|---|---|---|
| 0 | 1 | 0.540423 | 0.755742 | 0.491052 | 0.786327 |
| 1 | 2 | 0.488688 | 0.790828 | 0.458904 | 0.801004 |
| 2 | 3 | 0.479815 | 0.793685 | 0.473591 | 0.790122 |
| 3 | 4 | 0.475550 | 0.794473 | 0.440178 | 0.805256 |
| 4 | 5 | 0.471101 | 0.795940 | 0.421867 | 0.814793 |
| 5 | 6 | 0.469316 | 0.796485 | 0.441047 | 0.805040 |
| 6 | 7 | 0.466818 | 0.797602 | 0.431642 | 0.808139 |
| 7 | 8 | 0.464185 | 0.798508 | 0.441308 | 0.804055 |
| 8 | 9 | 0.463125 | 0.799167 | 0.414297 | 0.819789 |
| 9 | 10 | 0.460947 | 0.799857 | 0.446607 | 0.796536 |
| 10 | 11 | 0.459733 | 0.801015 | 0.416645 | 0.815753 |
| 11 | 12 | 0.457514 | 0.801772 | 0.421235 | 0.815513 |
| 12 | 13 | 0.456030 | 0.801973 | 0.435318 | 0.802373 |
| 13 | 14 | 0.455033 | 0.802606 | 0.433871 | 0.808547 |
| 14 | 15 | 0.453685 | 0.803795 | 0.435800 | 0.803478 |
| 15 | 16 | 0.451924 | 0.804356 | 0.406170 | 0.821447 |
| 16 | 17 | 0.451685 | 0.804104 | 0.427751 | 0.808979 |
| 17 | 18 | 0.449559 | 0.805524 | 0.391235 | 0.829734 |
| 18 | 19 | 0.449009 | 0.805946 | 0.437875 | 0.804391 |
| 19 | 20 | 0.448034 | 0.805828 | 0.422154 | 0.812655 |
# Visualise the baseline CNN learning curves from the fit history.
history_df = pd.DataFrame(history.history)
epochs = range(1, len(history_df) + 1)

# Loss curves
plt.figure(figsize=(9, 5))
plt.plot(epochs, history_df['loss'], 'bo', label='Training loss')
plt.plot(epochs, history_df['val_loss'], 'ro', label='Validation loss')
plt.xlabel('Epochs')
plt.xticks(epochs)
plt.ylabel('Loss')
plt.legend()
plt.title('Training and validation loss')
plt.show()

# Accuracy curves
plt.figure(figsize=(9, 5))
plt.plot(epochs, history_df['accuracy'], 'bo', label='Training accuracy')
plt.plot(epochs, history_df['val_accuracy'], 'ro', label='Validation accuracy')
plt.xlabel('Epochs')
plt.xticks(epochs)
plt.ylabel('Accuracy')
plt.legend()
plt.title('Training and validation accuracy')
plt.show()
# Evaluate the best checkpoint of the baseline CNN on the held-out test set.
best_vanilla_model = load_model("./models/convnet_from_scratch_with_augmentation.keras")

# Class probabilities, hard predictions and ground truth (the test generator
# is unshuffled, so .classes aligns with the prediction order).
y_pred_prob = best_vanilla_model.predict(test_generator)
y_pred = np.argmax(y_pred_prob, axis=1)
y_true = test_generator.classes
y_true_array = np.array(y_true)
class_labels = list(test_generator.class_indices.keys())

# Per-class precision / recall / f1 table.
display(pd.DataFrame(classification_report(y_true, y_pred, output_dict=True)).T)

# Confusion matrix.
cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_labels)
disp.plot(cmap='Blues')

# Precision-recall curve for the positive (IDC) class.
positive_class_prob = y_pred_prob[:, 1]
precision, recall, _ = precision_recall_curve(y_true_array == 1, positive_class_prob)
plt.figure(figsize=(9, 5))
plt.plot(recall, precision, "b-", linewidth=2)
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.show()

# Scalar metrics (precision/recall names are intentionally reused and now
# hold scalars rather than the curve arrays above).
f1 = f1_score(y_true, y_pred)
print(f'F1 Score: {f1}')
accuracy = accuracy_score(y_true, y_pred)
print(f'Accuracy: {accuracy}')
precision = precision_score(y_true, y_pred)
print(f'Precision: {precision}')
recall = recall_score(y_true, y_pred)
print(f'Recall: {recall}')

# Raw confusion-matrix counts and specificity (true-negative rate).
tn, fp, fn, tp = cm.ravel()
print(f'True Negatives: {tn}')
print(f'False Positives: {fp}')
print(f'False Negatives: {fn}')
print(f'True Positives: {tp}')
specificity = tn / (tn + fp)
print(f'Specificity: {specificity}')
1301/1301 [==============================] - 96s 74ms/step
| precision | recall | f1-score | support | |
|---|---|---|---|---|
| 0 | 0.883230 | 0.878100 | 0.880657 | 29959.000000 |
| 1 | 0.691658 | 0.701971 | 0.696776 | 11670.000000 |
| accuracy | 0.828725 | 0.828725 | 0.828725 | 0.828725 |
| macro avg | 0.787444 | 0.790035 | 0.788717 | 41629.000000 |
| weighted avg | 0.829526 | 0.828725 | 0.829109 | 41629.000000 |
F1 Score: 0.6967763885344901 Accuracy: 0.8287251675514665 Precision: 0.6916582235731172 Recall: 0.7019708654670094 True Negatives: 26307 False Positives: 3652 False Negatives: 3478 True Positives: 8192 Specificity: 0.8781000700957976
In this section, we finetuned the VGG16 model pre-trained on ImageNet. The base VGG16 layers are initially frozen, then the last four layers are fine-tuned to adapt to the specific dataset. It preprocesses our inputs for VGG16 compatibility, applies dropout for regularization, and includes a dense layer with dropout before the final softmax output. The model is compiled with categorical crossentropy and the Adadelta optimizer. It is trained using class weights to address imbalance, and saves the best model based on validation loss.
# VGG16 backbone pre-trained on ImageNet, without its classification head.
vgg_base = keras.applications.vgg16.VGG16(
    weights="imagenet",
    include_top=False,
    input_shape=(180, 180, 3)
)
# Freeze the backbone while the new classification head is wired up.
vgg_base.trainable = False
inputs = keras.Input(shape=(180, 180, 3))
# Apply VGG16's own input preprocessing inside the graph.
x = keras.applications.vgg16.preprocess_input(inputs)
x = vgg_base(x)
x = layers.Dropout(0.25)(x)
x = layers.Flatten()(x)
x = layers.Dense(128, activation='relu')(x)
x = layers.Dropout(0.5)(x)
# 2-way softmax head for the binary IDC classification.
outputs = layers.Dense(2, activation="softmax")(x)
model_vgg_finetuned = keras.Model(inputs, outputs)
# Fine-tuning setup: unfreeze the backbone, then re-freeze everything except
# its last four layers so only the deepest VGG features adapt to this data.
vgg_base.trainable = True
for layer in vgg_base.layers[:-4]:
    layer.trainable = False
model_vgg_finetuned.summary()
Model: "model_6"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_14 (InputLayer) [(None, 180, 180, 3)] 0
tf.__operators__.getitem_6 (None, 180, 180, 3) 0
(SlicingOpLambda)
tf.nn.bias_add_6 (TFOpLambd (None, 180, 180, 3) 0
a)
vgg16 (Functional) (None, 5, 5, 512) 14714688
dropout_12 (Dropout) (None, 5, 5, 512) 0
flatten_7 (Flatten) (None, 12800) 0
dense_14 (Dense) (None, 128) 1638528
dropout_13 (Dropout) (None, 128) 0
dense_15 (Dense) (None, 2) 258
=================================================================
Total params: 16,353,474
Trainable params: 8,718,210
Non-trainable params: 7,635,264
_________________________________________________________________
model_vgg_finetuned.compile(loss="categorical_crossentropy",
                            optimizer=keras.optimizers.Adadelta(),
                            metrics=["accuracy"])
# Save the weights with the lowest validation loss during fine-tuning;
# ModelCheckpoint does not create missing directories.
os.makedirs("./models", exist_ok=True)
callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath="./models/finetune-vgg16.keras",
        save_best_only=True,
        monitor="val_loss")
]
# The batch size is defined by the generator (32); passing batch_size to
# fit() alongside a generator is invalid/ignored, so it is not repeated here.
history_vgg_finetuned = model_vgg_finetuned.fit(
    train_generator,
    epochs=12,
    class_weight=class_weights,
    validation_data=valid_generator,
    callbacks=callbacks)
Epoch 1/12 6071/6071 [==============================] - 955s 157ms/step - loss: 0.7263 - accuracy: 0.5428 - val_loss: 0.6965 - val_accuracy: 0.5292 Epoch 2/12 6071/6071 [==============================] - 916s 151ms/step - loss: 0.6447 - accuracy: 0.6353 - val_loss: 0.6563 - val_accuracy: 0.6196 Epoch 3/12 6071/6071 [==============================] - 923s 152ms/step - loss: 0.6183 - accuracy: 0.6680 - val_loss: 0.7125 - val_accuracy: 0.5637 Epoch 4/12 6071/6071 [==============================] - 919s 151ms/step - loss: 0.5995 - accuracy: 0.6884 - val_loss: 0.6507 - val_accuracy: 0.6368 Epoch 5/12 6071/6071 [==============================] - 922s 152ms/step - loss: 0.5882 - accuracy: 0.6984 - val_loss: 0.7121 - val_accuracy: 0.5916 Epoch 6/12 6071/6071 [==============================] - 921s 152ms/step - loss: 0.5835 - accuracy: 0.7021 - val_loss: 0.6866 - val_accuracy: 0.6131 Epoch 7/12 6071/6071 [==============================] - 916s 151ms/step - loss: 0.5783 - accuracy: 0.7057 - val_loss: 0.6895 - val_accuracy: 0.6160 Epoch 8/12 6071/6071 [==============================] - 919s 151ms/step - loss: 0.5759 - accuracy: 0.7088 - val_loss: 0.5625 - val_accuracy: 0.7134 Epoch 9/12 6071/6071 [==============================] - 923s 152ms/step - loss: 0.5744 - accuracy: 0.7101 - val_loss: 0.6677 - val_accuracy: 0.6292 Epoch 10/12 6071/6071 [==============================] - 930s 153ms/step - loss: 0.5721 - accuracy: 0.7124 - val_loss: 0.6408 - val_accuracy: 0.6577 Epoch 11/12 6071/6071 [==============================] - 934s 154ms/step - loss: 0.5690 - accuracy: 0.7149 - val_loss: 0.6214 - val_accuracy: 0.6776 Epoch 12/12 6071/6071 [==============================] - 933s 154ms/step - loss: 0.5665 - accuracy: 0.7180 - val_loss: 0.7363 - val_accuracy: 0.5763
# Tabulate the per-epoch fine-tuning metrics with a leading 1-based epoch column.
history_df_vgg_finetuned = pd.DataFrame(history_vgg_finetuned.history)
history_df_vgg_finetuned.insert(0, 'epoch', [i + 1 for i in range(len(history_df_vgg_finetuned))])
history_df_vgg_finetuned
| epoch | loss | accuracy | val_loss | val_accuracy | |
|---|---|---|---|---|---|
| 0 | 1 | 0.726279 | 0.542848 | 0.696479 | 0.529222 |
| 1 | 2 | 0.644665 | 0.635309 | 0.656277 | 0.619568 |
| 2 | 3 | 0.618331 | 0.667991 | 0.712548 | 0.563742 |
| 3 | 4 | 0.599538 | 0.688427 | 0.650696 | 0.636816 |
| 4 | 5 | 0.588227 | 0.698434 | 0.712100 | 0.591583 |
| 5 | 6 | 0.583525 | 0.702146 | 0.686592 | 0.613106 |
| 6 | 7 | 0.578265 | 0.705666 | 0.689502 | 0.616037 |
| 7 | 8 | 0.575941 | 0.708786 | 0.562464 | 0.713397 |
| 8 | 9 | 0.574428 | 0.710088 | 0.667731 | 0.629177 |
| 9 | 10 | 0.572060 | 0.712425 | 0.640814 | 0.657667 |
| 10 | 11 | 0.569045 | 0.714912 | 0.621354 | 0.677629 |
| 11 | 12 | 0.566484 | 0.718031 | 0.736303 | 0.576305 |
# Visualise the fine-tuning learning curves.
epochs = range(1, len(history_df_vgg_finetuned) + 1)

# Loss curves
plt.figure(figsize=(9, 5))
plt.plot(epochs, history_df_vgg_finetuned['loss'], 'bo', label='Training loss')
plt.plot(epochs, history_df_vgg_finetuned['val_loss'], 'ro', label='Validation loss')
plt.xlabel('Epochs')
plt.xticks(epochs)
plt.ylabel('Loss')
plt.legend()
plt.title('Training and validation loss')
plt.show()

# Accuracy curves
plt.figure(figsize=(9, 5))
plt.plot(epochs, history_df_vgg_finetuned['accuracy'], 'bo', label='Training accuracy')
plt.plot(epochs, history_df_vgg_finetuned['val_accuracy'], 'ro', label='Validation accuracy')
plt.xlabel('Epochs')
plt.xticks(epochs)
plt.ylabel('Accuracy')
plt.legend()
plt.title('Training and validation accuracy')
plt.show()
# Evaluate the best fine-tuned VGG16 checkpoint on the held-out test set.
best_vgg_finetuned_model = load_model("./models/finetune-vgg16.keras")

# Class probabilities, hard predictions and ground truth (test generator is
# unshuffled, so .classes aligns with the prediction order).
y_pred_prob_vgg_finetuned = best_vgg_finetuned_model.predict(test_generator)
y_pred_vgg_finetuned = np.argmax(y_pred_prob_vgg_finetuned, axis=1)
y_true = test_generator.classes
y_true_array = np.array(y_true)
class_labels = list(test_generator.class_indices.keys())

# Per-class precision / recall / f1 table.
display(pd.DataFrame(classification_report(y_true, y_pred_vgg_finetuned, output_dict=True)).T)

# Confusion matrix.
cm_vgg_finetuned = confusion_matrix(y_true, y_pred_vgg_finetuned)
disp_vgg_finetuned = ConfusionMatrixDisplay(confusion_matrix=cm_vgg_finetuned, display_labels=class_labels)
disp_vgg_finetuned.plot(cmap='Blues')

# Precision-recall curve for the positive (IDC) class.
postive_class_prob_vgg_finetuned = y_pred_prob_vgg_finetuned[:, 1]
precision_vgg_finetuned, recall_vgg_finetuned, _ = precision_recall_curve(y_true_array == 1, postive_class_prob_vgg_finetuned)
plt.figure(figsize=(9, 5))
plt.plot(recall_vgg_finetuned, precision_vgg_finetuned, "b-", linewidth=2)
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.show()

# Scalar metrics for this model.
f1_vgg_finetuned = f1_score(y_true, y_pred_vgg_finetuned)
# BUG FIX: the original printed `f1` (the previous model's score) here.
print(f'F1 Score: {f1_vgg_finetuned}')
accuracy_vgg_finetuned = accuracy_score(y_true, y_pred_vgg_finetuned)
print(f'Accuracy: {accuracy_vgg_finetuned}')
precision_vgg_finetuned = precision_score(y_true, y_pred_vgg_finetuned)
print(f'Precision: {precision_vgg_finetuned}')
recall_vgg_finetuned = recall_score(y_true, y_pred_vgg_finetuned)
print(f'Recall: {recall_vgg_finetuned}')

# BUG FIX: the original unpacked `cm.ravel()` — the *previous* model's
# confusion matrix — so the printed counts and specificity belonged to the
# vanilla CNN, not the fine-tuned VGG16.
tn_vgg_finetuned, fp_vgg_finetuned, fn_vgg_finetuned, tp_vgg_finetuned = cm_vgg_finetuned.ravel()
print(f'True Negatives: {tn_vgg_finetuned}')
print(f'False Positives: {fp_vgg_finetuned}')
print(f'False Negatives: {fn_vgg_finetuned}')
print(f'True Positives: {tp_vgg_finetuned}')
# Specificity = true-negative rate.
specificity_vgg_finetuned = tn_vgg_finetuned / (tn_vgg_finetuned + fp_vgg_finetuned)
print(f'Specificity: {specificity_vgg_finetuned}')
1/1301 [..............................] - ETA: 1:101301/1301 [==============================] - 65s 50ms/step
| precision | recall | f1-score | support | |
|---|---|---|---|---|
| 0 | 0.867835 | 0.707500 | 0.779508 | 29959.000000 |
| 1 | 0.490671 | 0.723393 | 0.584727 | 11670.000000 |
| accuracy | 0.711956 | 0.711956 | 0.711956 | 0.711956 |
| macro avg | 0.679253 | 0.715447 | 0.682118 | 41629.000000 |
| weighted avg | 0.762103 | 0.711956 | 0.724905 | 41629.000000 |
F1 Score: 0.6967763885344901 Accuracy: 0.7119556078695141 Precision: 0.49067131647776807 Recall: 0.7233933161953727 True Negatives: 26307 False Positives: 3652 False Negatives: 3478 True Positives: 8192 Specificity: 0.8781000700957976
Here, this model integrates the ResNet50 architecture, pre-trained on ImageNet. The base ResNet50 layers are frozen to retain learned features, while additional custom layers include flattening, a dense layer, and dropout to combat overfitting. The final layer uses softmax for binary classification output. It's compiled with categorical crossentropy loss and the Adam optimizer. Training involves data augmentation, class weights for balancing, and callbacks to save improvements based on validation loss.
# ResNet50 backbone pre-trained on ImageNet, without its classification head,
# taking the same 180x180 RGB patches as the other models.
# NOTE(review): "restnet" is a typo for "resnet"; kept as-is because later
# cells may reference this name.
restnet_base = keras.applications.ResNet50(
    weights="imagenet",
    include_top=False,
    input_shape=(180, 180, 3)
)
# Inspect the backbone architecture.
restnet_base.summary()
Model: "resnet50"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_15 (InputLayer) [(None, 180, 180, 3 0 []
)]
conv1_pad (ZeroPadding2D) (None, 186, 186, 3) 0 ['input_15[0][0]']
conv1_conv (Conv2D) (None, 90, 90, 64) 9472 ['conv1_pad[0][0]']
conv1_bn (BatchNormalization) (None, 90, 90, 64) 256 ['conv1_conv[0][0]']
conv1_relu (Activation) (None, 90, 90, 64) 0 ['conv1_bn[0][0]']
pool1_pad (ZeroPadding2D) (None, 92, 92, 64) 0 ['conv1_relu[0][0]']
pool1_pool (MaxPooling2D) (None, 45, 45, 64) 0 ['pool1_pad[0][0]']
conv2_block1_1_conv (Conv2D) (None, 45, 45, 64) 4160 ['pool1_pool[0][0]']
conv2_block1_1_bn (BatchNormal (None, 45, 45, 64) 256 ['conv2_block1_1_conv[0][0]']
ization)
conv2_block1_1_relu (Activatio (None, 45, 45, 64) 0 ['conv2_block1_1_bn[0][0]']
n)
conv2_block1_2_conv (Conv2D) (None, 45, 45, 64) 36928 ['conv2_block1_1_relu[0][0]']
conv2_block1_2_bn (BatchNormal (None, 45, 45, 64) 256 ['conv2_block1_2_conv[0][0]']
ization)
conv2_block1_2_relu (Activatio (None, 45, 45, 64) 0 ['conv2_block1_2_bn[0][0]']
n)
conv2_block1_0_conv (Conv2D) (None, 45, 45, 256) 16640 ['pool1_pool[0][0]']
conv2_block1_3_conv (Conv2D) (None, 45, 45, 256) 16640 ['conv2_block1_2_relu[0][0]']
conv2_block1_0_bn (BatchNormal (None, 45, 45, 256) 1024 ['conv2_block1_0_conv[0][0]']
ization)
conv2_block1_3_bn (BatchNormal (None, 45, 45, 256) 1024 ['conv2_block1_3_conv[0][0]']
ization)
conv2_block1_add (Add) (None, 45, 45, 256) 0 ['conv2_block1_0_bn[0][0]',
'conv2_block1_3_bn[0][0]']
conv2_block1_out (Activation) (None, 45, 45, 256) 0 ['conv2_block1_add[0][0]']
conv2_block2_1_conv (Conv2D) (None, 45, 45, 64) 16448 ['conv2_block1_out[0][0]']
conv2_block2_1_bn (BatchNormal (None, 45, 45, 64) 256 ['conv2_block2_1_conv[0][0]']
ization)
conv2_block2_1_relu (Activatio (None, 45, 45, 64) 0 ['conv2_block2_1_bn[0][0]']
n)
conv2_block2_2_conv (Conv2D) (None, 45, 45, 64) 36928 ['conv2_block2_1_relu[0][0]']
conv2_block2_2_bn (BatchNormal (None, 45, 45, 64) 256 ['conv2_block2_2_conv[0][0]']
ization)
conv2_block2_2_relu (Activatio (None, 45, 45, 64) 0 ['conv2_block2_2_bn[0][0]']
n)
conv2_block2_3_conv (Conv2D) (None, 45, 45, 256) 16640 ['conv2_block2_2_relu[0][0]']
conv2_block2_3_bn (BatchNormal (None, 45, 45, 256) 1024 ['conv2_block2_3_conv[0][0]']
ization)
conv2_block2_add (Add) (None, 45, 45, 256) 0 ['conv2_block1_out[0][0]',
'conv2_block2_3_bn[0][0]']
conv2_block2_out (Activation) (None, 45, 45, 256) 0 ['conv2_block2_add[0][0]']
conv2_block3_1_conv (Conv2D) (None, 45, 45, 64) 16448 ['conv2_block2_out[0][0]']
conv2_block3_1_bn (BatchNormal (None, 45, 45, 64) 256 ['conv2_block3_1_conv[0][0]']
ization)
conv2_block3_1_relu (Activatio (None, 45, 45, 64) 0 ['conv2_block3_1_bn[0][0]']
n)
conv2_block3_2_conv (Conv2D) (None, 45, 45, 64) 36928 ['conv2_block3_1_relu[0][0]']
conv2_block3_2_bn (BatchNormal (None, 45, 45, 64) 256 ['conv2_block3_2_conv[0][0]']
ization)
conv2_block3_2_relu (Activatio (None, 45, 45, 64) 0 ['conv2_block3_2_bn[0][0]']
n)
conv2_block3_3_conv (Conv2D) (None, 45, 45, 256) 16640 ['conv2_block3_2_relu[0][0]']
conv2_block3_3_bn (BatchNormal (None, 45, 45, 256) 1024 ['conv2_block3_3_conv[0][0]']
ization)
conv2_block3_add (Add) (None, 45, 45, 256) 0 ['conv2_block2_out[0][0]',
'conv2_block3_3_bn[0][0]']
conv2_block3_out (Activation) (None, 45, 45, 256) 0 ['conv2_block3_add[0][0]']
conv3_block1_1_conv (Conv2D) (None, 23, 23, 128) 32896 ['conv2_block3_out[0][0]']
conv3_block1_1_bn (BatchNormal (None, 23, 23, 128) 512 ['conv3_block1_1_conv[0][0]']
ization)
conv3_block1_1_relu (Activatio (None, 23, 23, 128) 0 ['conv3_block1_1_bn[0][0]']
n)
conv3_block1_2_conv (Conv2D) (None, 23, 23, 128) 147584 ['conv3_block1_1_relu[0][0]']
conv3_block1_2_bn (BatchNormal (None, 23, 23, 128) 512 ['conv3_block1_2_conv[0][0]']
ization)
conv3_block1_2_relu (Activatio (None, 23, 23, 128) 0 ['conv3_block1_2_bn[0][0]']
n)
conv3_block1_0_conv (Conv2D) (None, 23, 23, 512) 131584 ['conv2_block3_out[0][0]']
conv3_block1_3_conv (Conv2D) (None, 23, 23, 512) 66048 ['conv3_block1_2_relu[0][0]']
conv3_block1_0_bn (BatchNormal (None, 23, 23, 512) 2048 ['conv3_block1_0_conv[0][0]']
ization)
conv3_block1_3_bn (BatchNormal (None, 23, 23, 512) 2048 ['conv3_block1_3_conv[0][0]']
ization)
conv3_block1_add (Add) (None, 23, 23, 512) 0 ['conv3_block1_0_bn[0][0]',
'conv3_block1_3_bn[0][0]']
conv3_block1_out (Activation) (None, 23, 23, 512) 0 ['conv3_block1_add[0][0]']
conv3_block2_1_conv (Conv2D) (None, 23, 23, 128) 65664 ['conv3_block1_out[0][0]']
conv3_block2_1_bn (BatchNormal (None, 23, 23, 128) 512 ['conv3_block2_1_conv[0][0]']
ization)
conv3_block2_1_relu (Activatio (None, 23, 23, 128) 0 ['conv3_block2_1_bn[0][0]']
n)
conv3_block2_2_conv (Conv2D) (None, 23, 23, 128) 147584 ['conv3_block2_1_relu[0][0]']
conv3_block2_2_bn (BatchNormal (None, 23, 23, 128) 512 ['conv3_block2_2_conv[0][0]']
ization)
conv3_block2_2_relu (Activatio (None, 23, 23, 128) 0 ['conv3_block2_2_bn[0][0]']
n)
conv3_block2_3_conv (Conv2D) (None, 23, 23, 512) 66048 ['conv3_block2_2_relu[0][0]']
conv3_block2_3_bn (BatchNormal (None, 23, 23, 512) 2048 ['conv3_block2_3_conv[0][0]']
ization)
conv3_block2_add (Add) (None, 23, 23, 512) 0 ['conv3_block1_out[0][0]',
'conv3_block2_3_bn[0][0]']
conv3_block2_out (Activation) (None, 23, 23, 512) 0 ['conv3_block2_add[0][0]']
conv3_block3_1_conv (Conv2D) (None, 23, 23, 128) 65664 ['conv3_block2_out[0][0]']
conv3_block3_1_bn (BatchNormal (None, 23, 23, 128) 512 ['conv3_block3_1_conv[0][0]']
ization)
conv3_block3_1_relu (Activatio (None, 23, 23, 128) 0 ['conv3_block3_1_bn[0][0]']
n)
conv3_block3_2_conv (Conv2D) (None, 23, 23, 128) 147584 ['conv3_block3_1_relu[0][0]']
conv3_block3_2_bn (BatchNormal (None, 23, 23, 128) 512 ['conv3_block3_2_conv[0][0]']
ization)
conv3_block3_2_relu (Activatio (None, 23, 23, 128) 0 ['conv3_block3_2_bn[0][0]']
n)
conv3_block3_3_conv (Conv2D) (None, 23, 23, 512) 66048 ['conv3_block3_2_relu[0][0]']
conv3_block3_3_bn (BatchNormal (None, 23, 23, 512) 2048 ['conv3_block3_3_conv[0][0]']
ization)
conv3_block3_add (Add) (None, 23, 23, 512) 0 ['conv3_block2_out[0][0]',
'conv3_block3_3_bn[0][0]']
conv3_block3_out (Activation) (None, 23, 23, 512) 0 ['conv3_block3_add[0][0]']
conv3_block4_1_conv (Conv2D) (None, 23, 23, 128) 65664 ['conv3_block3_out[0][0]']
conv3_block4_1_bn (BatchNormal (None, 23, 23, 128) 512 ['conv3_block4_1_conv[0][0]']
ization)
conv3_block4_1_relu (Activatio (None, 23, 23, 128) 0 ['conv3_block4_1_bn[0][0]']
n)
conv3_block4_2_conv (Conv2D) (None, 23, 23, 128) 147584 ['conv3_block4_1_relu[0][0]']
conv3_block4_2_bn (BatchNormal (None, 23, 23, 128) 512 ['conv3_block4_2_conv[0][0]']
ization)
conv3_block4_2_relu (Activatio (None, 23, 23, 128) 0 ['conv3_block4_2_bn[0][0]']
n)
conv3_block4_3_conv (Conv2D) (None, 23, 23, 512) 66048 ['conv3_block4_2_relu[0][0]']
conv3_block4_3_bn (BatchNormal (None, 23, 23, 512) 2048 ['conv3_block4_3_conv[0][0]']
ization)
conv3_block4_add (Add) (None, 23, 23, 512) 0 ['conv3_block3_out[0][0]',
'conv3_block4_3_bn[0][0]']
conv3_block4_out (Activation) (None, 23, 23, 512) 0 ['conv3_block4_add[0][0]']
conv4_block1_1_conv (Conv2D) (None, 12, 12, 256) 131328 ['conv3_block4_out[0][0]']
conv4_block1_1_bn (BatchNormal (None, 12, 12, 256) 1024 ['conv4_block1_1_conv[0][0]']
ization)
conv4_block1_1_relu (Activatio (None, 12, 12, 256) 0 ['conv4_block1_1_bn[0][0]']
n)
conv4_block1_2_conv (Conv2D) (None, 12, 12, 256) 590080 ['conv4_block1_1_relu[0][0]']
conv4_block1_2_bn (BatchNormal (None, 12, 12, 256) 1024 ['conv4_block1_2_conv[0][0]']
ization)
conv4_block1_2_relu (Activatio (None, 12, 12, 256) 0 ['conv4_block1_2_bn[0][0]']
n)
conv4_block1_0_conv (Conv2D) (None, 12, 12, 1024 525312 ['conv3_block4_out[0][0]']
)
conv4_block1_3_conv (Conv2D) (None, 12, 12, 1024 263168 ['conv4_block1_2_relu[0][0]']
)
conv4_block1_0_bn (BatchNormal (None, 12, 12, 1024 4096 ['conv4_block1_0_conv[0][0]']
ization) )
conv4_block1_3_bn (BatchNormal (None, 12, 12, 1024 4096 ['conv4_block1_3_conv[0][0]']
ization) )
conv4_block1_add (Add) (None, 12, 12, 1024 0 ['conv4_block1_0_bn[0][0]',
) 'conv4_block1_3_bn[0][0]']
conv4_block1_out (Activation) (None, 12, 12, 1024 0 ['conv4_block1_add[0][0]']
)
conv4_block2_1_conv (Conv2D) (None, 12, 12, 256) 262400 ['conv4_block1_out[0][0]']
conv4_block2_1_bn (BatchNormal (None, 12, 12, 256) 1024 ['conv4_block2_1_conv[0][0]']
ization)
conv4_block2_1_relu (Activatio (None, 12, 12, 256) 0 ['conv4_block2_1_bn[0][0]']
n)
conv4_block2_2_conv (Conv2D) (None, 12, 12, 256) 590080 ['conv4_block2_1_relu[0][0]']
conv4_block2_2_bn (BatchNormal (None, 12, 12, 256) 1024 ['conv4_block2_2_conv[0][0]']
ization)
conv4_block2_2_relu (Activatio (None, 12, 12, 256) 0 ['conv4_block2_2_bn[0][0]']
n)
conv4_block2_3_conv (Conv2D) (None, 12, 12, 1024 263168 ['conv4_block2_2_relu[0][0]']
)
conv4_block2_3_bn (BatchNormal (None, 12, 12, 1024 4096 ['conv4_block2_3_conv[0][0]']
ization) )
conv4_block2_add (Add) (None, 12, 12, 1024 0 ['conv4_block1_out[0][0]',
) 'conv4_block2_3_bn[0][0]']
conv4_block2_out (Activation) (None, 12, 12, 1024 0 ['conv4_block2_add[0][0]']
)
conv4_block3_1_conv (Conv2D) (None, 12, 12, 256) 262400 ['conv4_block2_out[0][0]']
conv4_block3_1_bn (BatchNormal (None, 12, 12, 256) 1024 ['conv4_block3_1_conv[0][0]']
ization)
conv4_block3_1_relu (Activatio (None, 12, 12, 256) 0 ['conv4_block3_1_bn[0][0]']
n)
conv4_block3_2_conv (Conv2D) (None, 12, 12, 256) 590080 ['conv4_block3_1_relu[0][0]']
conv4_block3_2_bn (BatchNormal (None, 12, 12, 256) 1024 ['conv4_block3_2_conv[0][0]']
ization)
conv4_block3_2_relu (Activatio (None, 12, 12, 256) 0 ['conv4_block3_2_bn[0][0]']
n)
conv4_block3_3_conv (Conv2D) (None, 12, 12, 1024 263168 ['conv4_block3_2_relu[0][0]']
)
conv4_block3_3_bn (BatchNormal (None, 12, 12, 1024 4096 ['conv4_block3_3_conv[0][0]']
ization) )
conv4_block3_add (Add) (None, 12, 12, 1024 0 ['conv4_block2_out[0][0]',
) 'conv4_block3_3_bn[0][0]']
conv4_block3_out (Activation) (None, 12, 12, 1024 0 ['conv4_block3_add[0][0]']
)
conv4_block4_1_conv (Conv2D) (None, 12, 12, 256) 262400 ['conv4_block3_out[0][0]']
conv4_block4_1_bn (BatchNormal (None, 12, 12, 256) 1024 ['conv4_block4_1_conv[0][0]']
ization)
conv4_block4_1_relu (Activatio (None, 12, 12, 256) 0 ['conv4_block4_1_bn[0][0]']
n)
conv4_block4_2_conv (Conv2D) (None, 12, 12, 256) 590080 ['conv4_block4_1_relu[0][0]']
conv4_block4_2_bn (BatchNormal (None, 12, 12, 256) 1024 ['conv4_block4_2_conv[0][0]']
ization)
conv4_block4_2_relu (Activatio (None, 12, 12, 256) 0 ['conv4_block4_2_bn[0][0]']
n)
conv4_block4_3_conv (Conv2D) (None, 12, 12, 1024 263168 ['conv4_block4_2_relu[0][0]']
)
conv4_block4_3_bn (BatchNormal (None, 12, 12, 1024 4096 ['conv4_block4_3_conv[0][0]']
ization) )
conv4_block4_add (Add) (None, 12, 12, 1024 0 ['conv4_block3_out[0][0]',
) 'conv4_block4_3_bn[0][0]']
conv4_block4_out (Activation) (None, 12, 12, 1024 0 ['conv4_block4_add[0][0]']
)
conv4_block5_1_conv (Conv2D) (None, 12, 12, 256) 262400 ['conv4_block4_out[0][0]']
conv4_block5_1_bn (BatchNormal (None, 12, 12, 256) 1024 ['conv4_block5_1_conv[0][0]']
ization)
conv4_block5_1_relu (Activatio (None, 12, 12, 256) 0 ['conv4_block5_1_bn[0][0]']
n)
conv4_block5_2_conv (Conv2D) (None, 12, 12, 256) 590080 ['conv4_block5_1_relu[0][0]']
conv4_block5_2_bn (BatchNormal (None, 12, 12, 256) 1024 ['conv4_block5_2_conv[0][0]']
ization)
conv4_block5_2_relu (Activatio (None, 12, 12, 256) 0 ['conv4_block5_2_bn[0][0]']
n)
conv4_block5_3_conv (Conv2D) (None, 12, 12, 1024 263168 ['conv4_block5_2_relu[0][0]']
)
conv4_block5_3_bn (BatchNormal (None, 12, 12, 1024 4096 ['conv4_block5_3_conv[0][0]']
ization) )
conv4_block5_add (Add) (None, 12, 12, 1024 0 ['conv4_block4_out[0][0]',
) 'conv4_block5_3_bn[0][0]']
conv4_block5_out (Activation) (None, 12, 12, 1024 0 ['conv4_block5_add[0][0]']
)
conv4_block6_1_conv (Conv2D) (None, 12, 12, 256) 262400 ['conv4_block5_out[0][0]']
conv4_block6_1_bn (BatchNormal (None, 12, 12, 256) 1024 ['conv4_block6_1_conv[0][0]']
ization)
conv4_block6_1_relu (Activatio (None, 12, 12, 256) 0 ['conv4_block6_1_bn[0][0]']
n)
conv4_block6_2_conv (Conv2D) (None, 12, 12, 256) 590080 ['conv4_block6_1_relu[0][0]']
conv4_block6_2_bn (BatchNormal (None, 12, 12, 256) 1024 ['conv4_block6_2_conv[0][0]']
ization)
conv4_block6_2_relu (Activatio (None, 12, 12, 256) 0 ['conv4_block6_2_bn[0][0]']
n)
conv4_block6_3_conv (Conv2D) (None, 12, 12, 1024 263168 ['conv4_block6_2_relu[0][0]']
)
conv4_block6_3_bn (BatchNormal (None, 12, 12, 1024 4096 ['conv4_block6_3_conv[0][0]']
ization) )
conv4_block6_add (Add) (None, 12, 12, 1024 0 ['conv4_block5_out[0][0]',
) 'conv4_block6_3_bn[0][0]']
conv4_block6_out (Activation) (None, 12, 12, 1024 0 ['conv4_block6_add[0][0]']
)
conv5_block1_1_conv (Conv2D) (None, 6, 6, 512) 524800 ['conv4_block6_out[0][0]']
conv5_block1_1_bn (BatchNormal (None, 6, 6, 512) 2048 ['conv5_block1_1_conv[0][0]']
ization)
conv5_block1_1_relu (Activatio (None, 6, 6, 512) 0 ['conv5_block1_1_bn[0][0]']
n)
conv5_block1_2_conv (Conv2D) (None, 6, 6, 512) 2359808 ['conv5_block1_1_relu[0][0]']
conv5_block1_2_bn (BatchNormal (None, 6, 6, 512) 2048 ['conv5_block1_2_conv[0][0]']
ization)
conv5_block1_2_relu (Activatio (None, 6, 6, 512) 0 ['conv5_block1_2_bn[0][0]']
n)
conv5_block1_0_conv (Conv2D) (None, 6, 6, 2048) 2099200 ['conv4_block6_out[0][0]']
conv5_block1_3_conv (Conv2D) (None, 6, 6, 2048) 1050624 ['conv5_block1_2_relu[0][0]']
conv5_block1_0_bn (BatchNormal (None, 6, 6, 2048) 8192 ['conv5_block1_0_conv[0][0]']
ization)
conv5_block1_3_bn (BatchNormal (None, 6, 6, 2048) 8192 ['conv5_block1_3_conv[0][0]']
ization)
conv5_block1_add (Add) (None, 6, 6, 2048) 0 ['conv5_block1_0_bn[0][0]',
'conv5_block1_3_bn[0][0]']
conv5_block1_out (Activation) (None, 6, 6, 2048) 0 ['conv5_block1_add[0][0]']
conv5_block2_1_conv (Conv2D) (None, 6, 6, 512) 1049088 ['conv5_block1_out[0][0]']
conv5_block2_1_bn (BatchNormal (None, 6, 6, 512) 2048 ['conv5_block2_1_conv[0][0]']
ization)
conv5_block2_1_relu (Activatio (None, 6, 6, 512) 0 ['conv5_block2_1_bn[0][0]']
n)
conv5_block2_2_conv (Conv2D) (None, 6, 6, 512) 2359808 ['conv5_block2_1_relu[0][0]']
conv5_block2_2_bn (BatchNormal (None, 6, 6, 512) 2048 ['conv5_block2_2_conv[0][0]']
ization)
conv5_block2_2_relu (Activatio (None, 6, 6, 512) 0 ['conv5_block2_2_bn[0][0]']
n)
conv5_block2_3_conv (Conv2D) (None, 6, 6, 2048) 1050624 ['conv5_block2_2_relu[0][0]']
conv5_block2_3_bn (BatchNormal (None, 6, 6, 2048) 8192 ['conv5_block2_3_conv[0][0]']
ization)
conv5_block2_add (Add) (None, 6, 6, 2048) 0 ['conv5_block1_out[0][0]',
'conv5_block2_3_bn[0][0]']
conv5_block2_out (Activation) (None, 6, 6, 2048) 0 ['conv5_block2_add[0][0]']
conv5_block3_1_conv (Conv2D) (None, 6, 6, 512) 1049088 ['conv5_block2_out[0][0]']
conv5_block3_1_bn (BatchNormal (None, 6, 6, 512) 2048 ['conv5_block3_1_conv[0][0]']
ization)
conv5_block3_1_relu (Activatio (None, 6, 6, 512) 0 ['conv5_block3_1_bn[0][0]']
n)
conv5_block3_2_conv (Conv2D) (None, 6, 6, 512) 2359808 ['conv5_block3_1_relu[0][0]']
conv5_block3_2_bn (BatchNormal (None, 6, 6, 512) 2048 ['conv5_block3_2_conv[0][0]']
ization)
conv5_block3_2_relu (Activatio (None, 6, 6, 512) 0 ['conv5_block3_2_bn[0][0]']
n)
conv5_block3_3_conv (Conv2D) (None, 6, 6, 2048) 1050624 ['conv5_block3_2_relu[0][0]']
conv5_block3_3_bn (BatchNormal (None, 6, 6, 2048) 8192 ['conv5_block3_3_conv[0][0]']
ization)
conv5_block3_add (Add) (None, 6, 6, 2048) 0 ['conv5_block2_out[0][0]',
'conv5_block3_3_bn[0][0]']
conv5_block3_out (Activation) (None, 6, 6, 2048) 0 ['conv5_block3_add[0][0]']
==================================================================================================
Total params: 23,587,712
Trainable params: 23,534,592
Non-trainable params: 53,120
__________________________________________________________________________________________________
# Freeze the pretrained ResNet50 backbone so only the new head is trained.
restnet_base.trainable = False

# Build a small classification head on top of the frozen backbone.
inputs = keras.Input(shape=(180, 180, 3))
features = keras.applications.resnet50.preprocess_input(inputs)  # ResNet's ImageNet preprocessing
features = restnet_base(features)
features = layers.Flatten()(features)
features = layers.Dense(256)(features)
features = layers.Dropout(0.5)(features)  # regularize the dense head
outputs = layers.Dense(2, activation="softmax")(features)  # 2 classes: non-IDC / IDC

model_restnet = keras.Model(inputs, outputs)
model_restnet.summary()
Model: "model_7"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_16 (InputLayer) [(None, 180, 180, 3)] 0
tf.__operators__.getitem_7 (None, 180, 180, 3) 0
(SlicingOpLambda)
tf.nn.bias_add_7 (TFOpLambd (None, 180, 180, 3) 0
a)
resnet50 (Functional) (None, 6, 6, 2048) 23587712
flatten_8 (Flatten) (None, 73728) 0
dense_16 (Dense) (None, 256) 18874624
dropout_14 (Dropout) (None, 256) 0
dense_17 (Dense) (None, 2) 514
=================================================================
Total params: 42,462,850
Trainable params: 18,875,138
Non-trainable params: 23,587,712
_________________________________________________________________
# Compile: 2-way softmax output with one-hot labels -> categorical cross-entropy.
model_restnet.compile(loss="categorical_crossentropy",
                      optimizer='adam',
                      metrics=["accuracy"])

# Keep only the weights from the epoch with the lowest validation loss.
callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath="./models/finetune-restnet.keras",
        save_best_only=True,
        monitor="val_loss")
]

# NOTE(review): `batch_size` is ignored by Model.fit when the input is a
# generator/dataset -- the generator's own batch size is used -- so the
# misleading `batch_size=1024` argument has been removed.
history_restnet = model_restnet.fit(
    train_generator,
    epochs=12,
    validation_data=valid_generator,
    class_weight=class_weights,  # compensate for the IDC/non-IDC imbalance
    callbacks=callbacks)
Epoch 1/12 6071/6071 [==============================] - 741s 122ms/step - loss: 1.3605 - accuracy: 0.6478 - val_loss: 0.9245 - val_accuracy: 0.4740 Epoch 2/12 6071/6071 [==============================] - 769s 127ms/step - loss: 0.7021 - accuracy: 0.6501 - val_loss: 0.7270 - val_accuracy: 0.6208 Epoch 3/12 6071/6071 [==============================] - 768s 126ms/step - loss: 0.6973 - accuracy: 0.6623 - val_loss: 0.4647 - val_accuracy: 0.7883 Epoch 4/12 6071/6071 [==============================] - 1053s 173ms/step - loss: 0.6910 - accuracy: 0.6705 - val_loss: 0.8821 - val_accuracy: 0.6563 Epoch 5/12 6071/6071 [==============================] - 820s 135ms/step - loss: 0.6906 - accuracy: 0.6748 - val_loss: 0.4949 - val_accuracy: 0.7684 Epoch 6/12 6071/6071 [==============================] - 758s 125ms/step - loss: 0.6848 - accuracy: 0.6795 - val_loss: 0.5071 - val_accuracy: 0.7624 Epoch 7/12 6071/6071 [==============================] - 757s 125ms/step - loss: 0.6772 - accuracy: 0.6841 - val_loss: 0.6249 - val_accuracy: 0.6772 Epoch 8/12 6071/6071 [==============================] - 760s 125ms/step - loss: 0.6799 - accuracy: 0.6840 - val_loss: 0.4475 - val_accuracy: 0.7981 Epoch 9/12 6071/6071 [==============================] - 1136s 187ms/step - loss: 0.6779 - accuracy: 0.6854 - val_loss: 0.4624 - val_accuracy: 0.7817 Epoch 10/12 6071/6071 [==============================] - 1006s 166ms/step - loss: 0.6808 - accuracy: 0.6869 - val_loss: 1.2335 - val_accuracy: 0.5147 Epoch 11/12 6071/6071 [==============================] - 772s 127ms/step - loss: 0.6722 - accuracy: 0.6906 - val_loss: 0.4552 - val_accuracy: 0.7937 Epoch 12/12 6071/6071 [==============================] - 775s 128ms/step - loss: 0.6768 - accuracy: 0.6917 - val_loss: 1.0478 - val_accuracy: 0.5238
# Tabulate the per-epoch training history, with a 1-based `epoch` column
# inserted up front for readability.
history_df_restnet = pd.DataFrame(history_restnet.history)
epoch_numbers = range(1, len(history_df_restnet) + 1)
history_df_restnet.insert(0, 'epoch', epoch_numbers)
history_df_restnet
| epoch | loss | accuracy | val_loss | val_accuracy | |
|---|---|---|---|---|---|
| 0 | 1 | 1.360464 | 0.647756 | 0.924516 | 0.473996 |
| 1 | 2 | 0.702105 | 0.650103 | 0.727046 | 0.620841 |
| 2 | 3 | 0.697281 | 0.662252 | 0.464693 | 0.788321 |
| 3 | 4 | 0.690979 | 0.670478 | 0.882149 | 0.656297 |
| 4 | 5 | 0.690584 | 0.674796 | 0.494930 | 0.768431 |
| 5 | 6 | 0.684795 | 0.679512 | 0.507101 | 0.762353 |
| 6 | 7 | 0.677205 | 0.684098 | 0.624910 | 0.677196 |
| 7 | 8 | 0.679915 | 0.683969 | 0.447450 | 0.798073 |
| 8 | 9 | 0.677909 | 0.685421 | 0.462432 | 0.781715 |
| 9 | 10 | 0.680843 | 0.686924 | 1.233491 | 0.514689 |
| 10 | 11 | 0.672183 | 0.690553 | 0.455240 | 0.793653 |
| 11 | 12 | 0.676808 | 0.691660 | 1.047788 | 0.523842 |
# Visualize the ResNet transfer-learning run: loss first, then accuracy.
epochs = range(1, len(history_df_restnet) + 1)

# Training vs. validation loss per epoch.
plt.figure(figsize=(9, 5))
plt.plot(epochs, history_df_restnet['loss'], 'bo', label='Training loss')
plt.plot(epochs, history_df_restnet['val_loss'], 'ro', label='Validation loss')
plt.xlabel('Epochs')
plt.xticks(epochs)
plt.ylabel('Loss')
plt.legend()
plt.title('Training and validation loss')
plt.show()

# Training vs. validation accuracy per epoch.
plt.figure(figsize=(9, 5))
plt.plot(epochs, history_df_restnet['accuracy'], 'bo', label='Training accuracy')
plt.plot(epochs, history_df_restnet['val_accuracy'], 'ro', label='Validation accuracy')
plt.xlabel('Epochs')
plt.xticks(epochs)
plt.ylabel('Accuracy')
plt.legend()
plt.title('Training and validation accuracy')
plt.show()
# Last-epoch (non-checkpointed) weights on the validation split.
model_restnet.evaluate(valid_generator)
1301/1301 [==============================] - 52s 40ms/step - loss: 1.0478 - accuracy: 0.5238
[1.0477871894836426, 0.5238415598869324]
# Last-epoch (non-checkpointed) weights on the held-out test split.
model_restnet.evaluate(test_generator)
1301/1301 [==============================] - 52s 40ms/step - loss: 1.0539 - accuracy: 0.5215
[1.0538926124572754, 0.521511435508728]
# Reload the best-val-loss checkpoint saved by ModelCheckpoint during training.
best_restnet_model = load_model("./models/finetune-restnet.keras")
best_restnet_model.evaluate(valid_generator)
1301/1301 [==============================] - 53s 40ms/step - loss: 0.4475 - accuracy: 0.7981
[0.44745033979415894, 0.7980734705924988]
# Best-checkpoint weights on the test split.
best_restnet_model.evaluate(test_generator)
1301/1301 [==============================] - 53s 41ms/step - loss: 0.4410 - accuracy: 0.8045
[0.4410139322280884, 0.8045353293418884]
# Evaluate the best ResNet50 checkpoint on the test set: classification
# report, confusion matrix, precision-recall curve, and scalar metrics.
# predict class probabilities
y_pred_prob_restnet = best_restnet_model.predict(test_generator)
# get the class with the highest probability
y_pred_restnet = np.argmax(y_pred_prob_restnet, axis=1)
# get the true class
y_true = test_generator.classes
y_true_array = np.array(y_true)
# get the class labels
class_labels = list(test_generator.class_indices.keys())
# get the classification report
display(pd.DataFrame(classification_report(y_true, y_pred_restnet, output_dict=True)).T)
# get the confusion matrix
cm_restnet = confusion_matrix(y_true, y_pred_restnet)
# plot the confusion matrix
disp_restnet = ConfusionMatrixDisplay(confusion_matrix=cm_restnet, display_labels=class_labels)
disp_restnet.plot(cmap='Blues')
# precision-recall curve for the positive (IDC = 1) class.
# NOTE: distinct names for the curve arrays so they are not clobbered by the
# scalar precision/recall scores computed below (the original reused
# precision_restnet/recall_restnet for both).
positive_class_prob_restnet = y_pred_prob_restnet[:, 1]
pr_curve_precision_restnet, pr_curve_recall_restnet, _ = precision_recall_curve(y_true_array == 1, positive_class_prob_restnet)
# plot the precision recall curve
plt.figure(figsize=(9, 5))
plt.plot(pr_curve_recall_restnet, pr_curve_precision_restnet, "b-", linewidth=2)
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.show()
# get the f1 score
f1_restnet = f1_score(y_true, y_pred_restnet)
# BUG FIX: the original printed `f1` (a leftover from the vanilla model's
# evaluation) instead of this model's f1_restnet.
print(f'F1 Score: {f1_restnet}')
# get the accuracy
accuracy_restnet = accuracy_score(y_true, y_pred_restnet)
print(f'Accuracy: {accuracy_restnet}')
# get the precision
precision_restnet = precision_score(y_true, y_pred_restnet)
print(f'Precision: {precision_restnet}')
# get the recall
recall_restnet = recall_score(y_true, y_pred_restnet)
print(f'Recall: {recall_restnet}')
# from the confusion matrix, calculate tn, fp, fn, tp
# BUG FIX: the original unpacked `cm.ravel()` -- the vanilla model's
# confusion matrix from an earlier cell -- instead of cm_restnet.
tn_restnet, fp_restnet, fn_restnet, tp_restnet = cm_restnet.ravel()
print(f'True Negatives: {tn_restnet}')
print(f'False Positives: {fp_restnet}')
print(f'False Negatives: {fn_restnet}')
print(f'True Positives: {tp_restnet}')
# specificity = TN / (TN + FP), i.e. recall of the negative class
specificity_restnet = tn_restnet / (tn_restnet + fp_restnet)
print(f'Specificity: {specificity_restnet}')
1301/1301 [==============================] - 50s 38ms/step
| precision | recall | f1-score | support | |
|---|---|---|---|---|
| 0 | 0.825857 | 0.923028 | 0.871743 | 29959.000000 |
| 1 | 0.716882 | 0.500343 | 0.589352 | 11670.000000 |
| accuracy | 0.804535 | 0.804535 | 0.804535 | 0.804535 |
| macro avg | 0.771369 | 0.711685 | 0.730547 | 41629.000000 |
| weighted avg | 0.795308 | 0.804535 | 0.792579 | 41629.000000 |
F1 Score: 0.6967763885344901 Accuracy: 0.8045352999111196 Precision: 0.7168815224063843 Recall: 0.5003427592116538 True Negatives: 26307 False Positives: 3652 False Negatives: 3478 True Positives: 8192 Specificity: 0.8781000700957976
After evaluating our models on the test dataset, there are some things to notice.
For our Vanilla model, this was our test performance.
# Classification report for the Vanilla CNN's test predictions (y_pred,
# computed in an earlier cell).
display(pd.DataFrame(classification_report(y_true, y_pred, output_dict=True)).T)
| precision | recall | f1-score | support | |
|---|---|---|---|---|
| 0 | 0.883230 | 0.878100 | 0.880657 | 29959.000000 |
| 1 | 0.691658 | 0.701971 | 0.696776 | 11670.000000 |
| accuracy | 0.828725 | 0.828725 | 0.828725 | 0.828725 |
| macro avg | 0.787444 | 0.790035 | 0.788717 | 41629.000000 |
| weighted avg | 0.829526 | 0.828725 | 0.829109 | 41629.000000 |
For the transfer learning with the VGG16 model, we had
# Classification report for the fine-tuned VGG16 test predictions
# (y_pred_vgg_finetuned, computed in an earlier cell).
display(pd.DataFrame(classification_report(y_true, y_pred_vgg_finetuned, output_dict=True)).T)
| precision | recall | f1-score | support | |
|---|---|---|---|---|
| 0 | 0.867835 | 0.707500 | 0.779508 | 29959.000000 |
| 1 | 0.490671 | 0.723393 | 0.584727 | 11670.000000 |
| accuracy | 0.711956 | 0.711956 | 0.711956 | 0.711956 |
| macro avg | 0.679253 | 0.715447 | 0.682118 | 41629.000000 |
| weighted avg | 0.762103 | 0.711956 | 0.724905 | 41629.000000 |
and then for ResNet50 we had
# Classification report for the ResNet50 test predictions (y_pred_restnet).
display(pd.DataFrame(classification_report(y_true, y_pred_restnet, output_dict=True)).T)
| precision | recall | f1-score | support | |
|---|---|---|---|---|
| 0 | 0.825857 | 0.923028 | 0.871743 | 29959.000000 |
| 1 | 0.716882 | 0.500343 | 0.589352 | 11670.000000 |
| accuracy | 0.804535 | 0.804535 | 0.804535 | 0.804535 |
| macro avg | 0.771369 | 0.711685 | 0.730547 | 41629.000000 |
| weighted avg | 0.795308 | 0.804535 | 0.792579 | 41629.000000 |
The Vanilla model achieved the best test accuracy at 0.8287. For the positive (IDC) class, VGG16 had the best recall at 0.723, but the Vanilla model is not too far behind at 0.702 — this is very important — and the Vanilla model also had the best macro-average recall at 0.79. When evaluating machine learning models designed for healthcare, the most important metric is the recall score. It is perfectly okay to sacrifice the precision and accuracy score in favour of the recall score.
One way to do this without retraining or modifying our model is to simply reduce the threshold used to classify a sample as positive. This way, it becomes harder for a patient that is positive to slip through our prediction, ensuring that we don't send home a sick patient because our model predicted a false negative.
# Trade precision for recall by lowering the positive-class decision
# threshold (no retraining needed).
# Predict class probabilities.
# NOTE(review): `model` here appears to be the Vanilla CNN (its default-
# threshold recall is the 0.702 baseline quoted below) -- confirm this is
# the intended model.
# BUG FIX: the original stored these predictions in `y_pred_prob_restnet`,
# silently overwriting the ResNet probabilities computed earlier; use a
# dedicated name instead.
y_pred_prob_thresholding = model.predict(test_generator)
# Get the true class labels
y_true = test_generator.classes
# Get the probability of the positive class (index 1 is the IDC-positive class)
positive_class_prob = y_pred_prob_thresholding[:, 1]
# Set a lower-than-default threshold so fewer positive patients slip
# through as false negatives.
threshold = 0.4
# Convert probabilities to binary predictions based on the threshold
y_pred_threshold = (positive_class_prob >= threshold).astype(int)
# Calculate recall
recall = recall_score(y_true, y_pred_threshold)
# Print the threshold and its recall
print(f"Threshold: {threshold}")
print(f"Recall: {recall}")
display(pd.DataFrame(classification_report(y_true, y_pred_threshold, output_dict=True)).T)
1301/1301 [==============================] - 27s 21ms/step Threshold: 0.4 Recall: 0.8444730077120822
| precision | recall | f1-score | support | |
|---|---|---|---|---|
| 0 | 0.926153 | 0.759805 | 0.834773 | 29959.00000 |
| 1 | 0.577972 | 0.844473 | 0.686257 | 11670.00000 |
| accuracy | 0.783540 | 0.783540 | 0.783540 | 0.78354 |
| macro avg | 0.752063 | 0.802139 | 0.760515 | 41629.00000 |
| weighted avg | 0.828547 | 0.783540 | 0.793139 | 41629.00000 |
As you can observe above, the recall score for class 1 has now gone up from 0.702 to 0.844, and the macro-average recall from 0.79 to 0.802.
You can also see the accuracy, precision and recall score for class 0 have all reduced.
The performance of the transfer/pretrained models were not as good as I would have loved but not too surprising at the same time. These pretrained models were trained on daily life and everyday images, while our dataset is very niche, so they probably couldn't really help.
Trying a different approach for balancing the dataset.
Training for more epochs.
I also observed that when I played around with evaluating my test dataset with the last epoch from my fitting and not the best epoch, it had a better recall score than the best epoch.